# 🧪 Experimental: Task Selection & Scheduling System

```{note}
This module is currently in **experimental status**. Interfaces may change in future versions.
This document describes the functionality and intended usage of the system.
```

## Overview

This system enables **intelligent, adaptive task sampling** from multiple datasets (called *tasksets*) during exploration. It consists of two core components:

1. **`Selector`** – Controls how individual samples are selected *within* each taskset.
2. **`TasksetScheduler`** – Manages *which* tasksets contribute to each batch and coordinates their sampling.

Together, they support advanced training strategies such as:
- Curriculum learning (easy → hard)
- Multi-task interleaving or mixing
- Difficulty-aware sampling
- Adaptive data selection based on model performance

These capabilities allow you to train models more efficiently by focusing on informative or challenging examples.

## Module 1: Selector – Customizable Data Selection

A `Selector` determines **which tasks (samples) to select** from its associated dataset (`Taskset`). Beyond basic strategies like sequential or random access, it supports **adaptive algorithms** that adjust sampling based on feedback—such as sample difficulty, model confidence, or reward signals.

### Built-in Selectors

| Selector Type | Description |
|---------------|-------------|
| `sequential` | Returns samples in fixed order (0, 1, ..., N). |
| `shuffle` | Shuffles the dataset once per epoch; then iterates sequentially. |
| `random` | Randomly samples without replacement within each batch. Independent across batches. |
| `offline_easy2hard` | Sorts samples by pre-defined features (e.g., loss, length), serving easier ones first, progressing to harder ones. |
| `difficulty_based` *(custom example)* | Dynamically selects samples near a target difficulty level using probabilistic modeling. |

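
Choosing one of these built-in strategies is a small configuration change on the taskset, for example (an abridged sketch; the full set of taskset fields appears in Step 3):

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        task_selector:
          selector_type: shuffle  # any built-in or registered selector name
```
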
You can also **implement your own custom selector** to enable adaptive or curriculum-based learning.

### ✅ Step 1: Implement a Custom Selector

To create a new selector, inherit from `BaseSelector` and implement the following methods:

#### Required Methods

| Method | Purpose |
|--------|---------|
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. |
| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
| `state_dict() -> Dict` | Serialize current state for checkpointing. |
| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. |

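
Before diving into a full adaptive example, the sketch below wires up all four required methods with a trivial wrap-around sequential policy. It is illustrative only: the decorator and constructor signature follow the `DifficultyBasedSelector` example shown next, and the `my_sequential` name is hypothetical.

```python
from typing import Dict, List


@SELECTORS.register_module("my_sequential")  # hypothetical name, for illustration
class MySequentialSelector(BaseSelector):
    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        # Assumes the underlying dataset exposes __len__.
        self.dataset_size = len(data_source.dataset)
        self.current_index = 0

    def get_indices(self, batch_size: int, return_extra_info: bool = False):
        # Hand out the next `batch_size` indices, wrapping around the dataset.
        indices = [(self.current_index + i) % self.dataset_size for i in range(batch_size)]
        self.current_index = (self.current_index + batch_size) % self.dataset_size
        if return_extra_info:
            return indices, {"indices": indices}
        return indices

    def update(self, indices: List[int], values: List[float]) -> None:
        # A fixed-order policy ignores feedback; adaptive selectors use it to
        # re-estimate sample difficulty or value.
        pass

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```
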
#### Example: `DifficultyBasedSelector`

This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.

```python
from typing import Dict, List

import torch

# SELECTORS, BaseSelector, TaskSelectorConfig, and get_logger are provided by the framework.


@SELECTORS.register_module("difficulty_based")
class DifficultyBasedSelector(BaseSelector):
    def __init__(self, data_source, config: TaskSelectorConfig) -> None:
        super().__init__(data_source, config)
        self.logger = get_logger("difficulty_based_selector")

        # Build difficulty estimator using two input features (e.g., correctness, uncertainty)
        self.diff_estimator = self.build_diff_estimator(
            data_source.dataset, config.feature_keys, config.kwargs
        )
        self.current_index = 0
        self.seed = config.seed

        # Configuration parameters
        self.do_sample = config.kwargs.get("do_sample", False)
        self.target_reward = config.kwargs.get("target_reward", 1.0)
        self.tau = config.kwargs.get("tau", 1.0)

        # ... detailed implementation

    def get_indices(self, batch_size, return_extra_info=False):
        # Compute scores based on proximity to target reward
        sampling_scores = self.get_scores()
        sampling_scores = torch.from_numpy(sampling_scores)

        if self.tau == 0:
            # Greedy: take top-k highest scoring samples
            selected_indices = torch.topk(sampling_scores, batch_size).indices
        else:
            # Stochastic: sample via softmax with temperature scaling
            sampling_logits = sampling_scores / self.tau
            sampling_logits -= sampling_logits.max()  # Numerical stability
            sampling_probabilities = torch.softmax(sampling_logits, dim=0)
            rng = torch.Generator().manual_seed(self.seed + self.current_index)
            selected_indices = torch.multinomial(
                sampling_probabilities,
                batch_size,
                replacement=False,
                generator=rng,
            )

        self.current_index += batch_size

        if return_extra_info:
            # Optional debugging info
            extra_info = {
                "indices": selected_indices.tolist(),
                "scores": sampling_scores[selected_indices].tolist(),
                # ... other metadata
            }
            return selected_indices, extra_info
        else:
            return selected_indices

    def update(self, indices: List[int], values: List[float]) -> None:
        # Update difficulty model with observed rewards
        self.diff_estimator.update(indices, values)

    def state_dict(self) -> Dict:
        return {"current_index": self.current_index}

    def load_state_dict(self, state_dict: Dict) -> None:
        self.current_index = state_dict.get("current_index", 0)
```

> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.

### ✅ Step 2: Implement a Feedback Operator

For adaptive selectors like `DifficultyBasedSelector`, you need to provide runtime feedback (e.g., task rewards). This is done via an **Experience Operator** that processes rollouts and computes metrics.

> 📚 See the {ref}`Operator Development Guide<Operators>` for more on building custom experience processors.

The operator must output a metric under the key `trinity.common.constants.SELECTOR_METRIC`, structured as:

```python
{
    SELECTOR_METRIC: {
        0: {  # taskset_id
            "indices": [10, 25, 43],
            "values": [0.8, 0.6, 0.9]  # e.g., average reward
        },
        1: { ... }
    }
}
```

#### Example: Pass Rate Calculator

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np

from trinity.common.constants import SELECTOR_METRIC

# EXPERIENCE_OPERATORS, ExperienceOperator, and Experience are provided by the framework.


@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
    def __init__(self, **kwargs):
        pass

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        # Group observed rewards by taskset and by task index within that taskset.
        raw_metric = defaultdict(lambda: defaultdict(list))

        for exp in exps:
            task_index = exp.info["task_index"]
            assert "taskset_id" in task_index and "index" in task_index
            raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward)

        metric = {}
        for taskset_id, task_metrics in raw_metric.items():
            indices = []
            reward_means = []
            for idx, rewards in task_metrics.items():
                indices.append(idx)
                reward_means.append(float(np.mean(rewards)))
            metric[taskset_id] = {
                "indices": indices,
                "values": reward_means,
            }

        return exps, {SELECTOR_METRIC: metric}
```

This operator calculates the average reward per task and passes it back to the corresponding selector for updating difficulty estimates.

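
Conceptually, the scheduler routes each taskset's entry in the returned metric to the selector attached to that taskset. A schematic sketch of that routing is shown below; `selectors` is a hypothetical per-taskset lookup, and the actual wiring lives inside the framework.

```python
# Schematic only: how the SELECTOR_METRIC payload maps onto Selector.update calls.
for taskset_id, feedback in metrics[SELECTOR_METRIC].items():
    # `selectors[taskset_id]` stands for the Selector instance of that taskset.
    selectors[taskset_id].update(feedback["indices"], feedback["values"])
```
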

### ✅ Step 3: Update Configuration

After implementing your selector and operator, register them in the config file.

#### Add the Operator to the Pipeline

```yaml
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator  # Must match @register_module name
```

#### Configure the Taskset with Your Selector

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: my_taskset
        storage_type: file
        path: ./path/to/tasks
        task_selector:
          selector_type: difficulty_based  # Matches @register_module name
          feature_keys: ["correct", "uncertainty"]
          kwargs:
            m: 16
            lamb: 0.2
            rho: 0.2
            target_reward: 0.9
            tau: 0.5
            do_sample: true
```

> 💡 You can define multiple tasksets, each with its own selector type and configuration.

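
For instance, a setup that mixes a curriculum-ordered taskset with a plainly shuffled one could look like the sketch below (names, paths, and feature keys are placeholders):

```yaml
buffer:
  explorer_input:
    tasksets:
      - name: math_tasks
        storage_type: file
        path: ./path/to/math
        task_selector:
          selector_type: offline_easy2hard
          feature_keys: ["length"]  # pre-computed difficulty feature (placeholder)
      - name: code_tasks
        storage_type: file
        path: ./path/to/code
        task_selector:
          selector_type: shuffle
```
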
## Module 2: TasksetScheduler – Multi-Taskset Orchestration

The `TasksetScheduler` manages **how different tasksets are interleaved or mixed** during training.

### Key Features

- Supports **multiple tasksets** simultaneously.
- Balances sampling proportionally to dataset sizes.
- **Shuffles taskset access order** at the start of each epoch.
- Enables **curriculum-style** or **interleaved multi-task training**.
- Fully **checkpointable**: resumes exactly where it left off.
- Integrates with any registered `Selector`.

### How It Works

At each training step, the scheduler:
1. Determines which tasksets should contribute to the current batch.
2. Queries each taskset’s selector to get specific sample indices.
3. Reads the actual data asynchronously.
4. Tags each task with `"taskset_id"` for downstream routing or analysis.

Epochs are defined based on total data volume and batch size:
```python
steps_per_epoch = total_samples // batch_size
```

At the beginning of each epoch, the scheduler reshuffles the sequence of taskset accesses to introduce variability.

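
Putting the pieces together, a single scheduling step can be pictured roughly as follows. This is pseudocode under assumed names (`plan_for_step`, `per_taskset_quota`, `read_tasks`, `selectors` are hypothetical); the real scheduler also handles asynchronous reads and checkpointing.

```python
def next_batch(step: int, batch_size: int):
    """Schematic of one scheduling step; not the actual implementation."""
    batch = []
    # 1. Decide which tasksets contribute to this batch (order reshuffled every epoch).
    for taskset_id in plan_for_step(step):                  # hypothetical helper
        quota = per_taskset_quota(taskset_id, batch_size)   # share proportional to dataset size
        # 2. Ask this taskset's selector for concrete sample indices.
        indices = selectors[taskset_id].get_indices(quota)
        # 3. Read the selected tasks (done asynchronously in practice).
        for index, task in zip(indices, read_tasks(taskset_id, indices)):  # hypothetical reader
            # 4. Tag each task so downstream operators know its origin
            #    (the same "task_index" layout read by PassRateCalculator above).
            task.info["task_index"] = {"taskset_id": taskset_id, "index": index}
            batch.append(task)
    return batch
```
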

## Summary

With these components, you can:
- Use simple strategies like random or sequential sampling.
- Design **adaptive curricula** using custom selectors.
- Combine multiple datasets intelligently.
- Optimize training efficiency by focusing on high-value samples.

By combining smart `Selectors` with the flexible `TasksetScheduler`, you gain fine-grained control over what your model sees—and when.