Skip to content

Commit ab66fae

Browse files
authored
Move Selector into TaskFileReader (#486)
1 parent 33a94bb commit ab66fae

File tree

8 files changed

+62
-52
lines changed

8 files changed

+62
-52
lines changed

docs/sphinx_doc/source/tutorial/develop_selector.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
5151
| Method | Purpose |
5252
|-------|--------|
5353
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | Return a list of sample indices to read next. |
54-
| `update(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
54+
| `feedback(indices: List[int], values: List[float])` | Update internal state using feedback (e.g., rewards, losses). |
5555
| `state_dict() -> Dict` | Serialize current state for checkpointing. |
5656
| `load_state_dict(state_dict: Dict)` | Restore state from a saved dictionary. |
5757

@@ -113,7 +113,7 @@ class DifficultyBasedSelector(BaseSelector):
113113
else:
114114
return selected_indices
115115

116-
def update(self, indices: List[int], values: List[float]) -> None:
116+
def feedback(self, indices: List[int], values: List[float]) -> None:
117117
# Update difficulty model with observed rewards
118118
self.diff_estimator.update(indices, values)
119119

docs/sphinx_doc/source_zh/tutorial/develop_selector.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
| 方法 | 功能说明 |
5050
|------|---------|
5151
| `get_indices(batch_size: int, return_extra_info=False) -> List[int]` | 返回接下来要读取的样本索引列表。 |
52-
| `update(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 |
52+
| `feedback(indices: List[int], values: List[float])` | 使用反馈信息(如奖励、损失)更新内部状态,用于自适应调整。 |
5353
| `state_dict() -> Dict` | 序列化当前状态,用于保存检查点。 |
5454
| `load_state_dict(state_dict: Dict)` | 从保存的状态字典中恢复选择器状态。 |
5555

@@ -111,7 +111,7 @@ class DifficultyBasedSelector(BaseSelector):
111111
else:
112112
return selected_indices
113113

114-
def update(self, indices: List[int], values: List[float]) -> None:
114+
def feedback(self, indices: List[int], values: List[float]) -> None:
115115
# 使用观测到的奖励更新难度模型
116116
self.diff_estimator.update(indices, values)
117117

tests/buffer/task_scheduler_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ async def test_task_scheduler_simple(self):
340340
self.assertEqual(len(task_scheduler_state), 1)
341341
self.assertEqual(task_scheduler_state[0]["current_index"], 4)
342342
# no effect
343-
task_scheduler.update({"metric1": 0.5})
343+
task_scheduler.feedback({"metric1": 0.5})
344344

345345
task_scheduler = get_taskset_scheduler(
346346
{

trinity/buffer/reader/file_reader.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@ def select_batch(self, indices: List[int]) -> List:
7979
batch = []
8080
for i in indices:
8181
assert 0 <= i < self.dataset_size
82+
if self.current_offset >= self.total_samples:
83+
if not self.drop_last and len(batch) > 0:
84+
break
85+
self.progress_bar.close()
86+
raise StopIteration
8287
batch.append(self.dataset[int(i)])
88+
self.current_offset += 1
89+
8390
self.progress_bar.update(len(batch)) # update progress bar
8491
return batch
8592

@@ -104,20 +111,16 @@ def __init__(self, config: StorageConfig):
104111
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
105112
return self.reader.read(batch_size)
106113

107-
def read_with_indices(self, indices: List[int]) -> List:
108-
"""Read tasks with indices."""
109-
return self.reader.read_with_indices(indices)
110-
111-
async def read_with_indices_async(self, indices: List[int]) -> List:
112-
"""Read tasks with indices asynchronously."""
113-
return await self.reader.read_with_indices_async(indices)
114-
115114
def state_dict(self):
116115
return self.reader.state_dict()
117116

118117
def load_state_dict(self, state_dict):
119118
return self.reader.load_state_dict(state_dict)
120119

120+
def feedback(self, **pipeline_metrics):
121+
if self.reader.selector is not None:
122+
self.reader.selector.feedback(**pipeline_metrics)
123+
121124
def __len__(self):
122125
return self.reader.__len__()
123126

@@ -139,6 +142,7 @@ def __init__(self, config: StorageConfig):
139142
total_steps=config.total_steps,
140143
enable_progress_bar=config.enable_progress_bar,
141144
)
145+
self.selector = None
142146

143147
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
144148
samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size)
@@ -178,6 +182,15 @@ def __init__(self, config: StorageConfig):
178182
enable_progress_bar=self.config.enable_progress_bar,
179183
)
180184
self.formatter = FORMATTER.get("task")(config)
185+
if self.config.task_selector is not None:
186+
from trinity.buffer.selector import SELECTORS
187+
from trinity.buffer.selector.selector import BaseSelector
188+
189+
self.selector: BaseSelector = SELECTORS.get(self.config.task_selector.selector_type)(
190+
self.dataset, self.config.task_selector
191+
)
192+
else:
193+
self.selector = None
181194

182195
def _get_tasks(self, samples: List, indices: List) -> List:
183196
tasks = []
@@ -189,22 +202,21 @@ def _get_tasks(self, samples: List, indices: List) -> List:
189202

190203
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
191204
batch_size = batch_size or self.read_batch_size
192-
samples, indices = self.dataset.read_batch(batch_size)
193-
return self._get_tasks(samples, indices)
194-
195-
def read_with_indices(self, indices: List[int]) -> List:
196-
"""Read tasks with indices."""
197-
samples = self.dataset.select_batch(indices)
205+
if self.selector is not None:
206+
indices = self.selector.get_indices(batch_size)
207+
samples = self.dataset.select_batch(indices)
208+
else:
209+
samples, indices = self.dataset.read_batch(batch_size)
198210
return self._get_tasks(samples, indices)
199211

200-
async def read_with_indices_async(self, indices: List[int]) -> List:
201-
"""Read tasks with indices asynchronously."""
202-
return self.read_with_indices(indices)
203-
204212
def state_dict(self):
213+
if self.selector is not None:
214+
return self.selector.state_dict()
205215
return {"current_index": self.dataset.current_offset}
206216

207217
def load_state_dict(self, state_dict):
218+
if self.selector is not None:
219+
self.selector.load_state_dict(state_dict)
208220
self.dataset.current_offset = state_dict["current_index"]
209221

210222
def __len__(self):

trinity/buffer/selector/selector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
4444
"""
4545
raise NotImplementedError
4646

47-
def update(self, indices: List[int], values: List[float]) -> None:
47+
def feedback(self, indices: List[int], values: List[float]) -> None:
4848
"""
4949
Update internal state based on feedback (e.g., model loss, accuracy).
5050
@@ -95,7 +95,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
9595
return list(range(start, end))
9696
return list(range(start, self.dataset_size)) + list(range(0, end - self.dataset_size))
9797

98-
def update(self, indices: List[int], values: List[float]) -> None:
98+
def feedback(self, indices: List[int], values: List[float]) -> None:
9999
# No-op: sequential selection doesn't adapt based on feedback
100100
pass
101101

@@ -150,7 +150,7 @@ def get_indices(self, batch_size: int, return_extra_info: bool = False) -> List[
150150
self.current_index += batch_size
151151
return ret
152152

153-
def update(self, indices: List[int], values: List[float]) -> None:
153+
def feedback(self, indices: List[int], values: List[float]) -> None:
154154
# No-op: static shuffling does not adapt
155155
pass
156156

@@ -188,7 +188,7 @@ def get_indices(self, batch_size, return_extra_info=False):
188188
else:
189189
return selected_indices
190190

191-
def update(self, indices: List[int], values: List[float]) -> None:
191+
def feedback(self, indices: List[int], values: List[float]) -> None:
192192
# No-op: basic random selection doesn't adapt
193193
pass
194194

@@ -239,7 +239,7 @@ def __init__(self, data_source, config: TaskSelectorConfig):
239239
self.dataset_size = data_source.dataset_size
240240
self.current_index = 0
241241

242-
def update(self, indices: List[int], values: List[float]) -> None:
242+
def feedback(self, indices: List[int], values: List[float]) -> None:
243243
# No-op: this selector does not adapt based on runtime feedback
244244
pass
245245

@@ -340,7 +340,7 @@ def build_diff_estimator(self, dataset, feature_keys: List[str], config: dict):
340340
adaptive_rho=adaptive_rho,
341341
)
342342

343-
def update(self, indices: List[int], values: List[float]) -> None:
343+
def feedback(self, indices: List[int], values: List[float]) -> None:
344344
"""
345345
Updates the difficulty estimator with observed performance on selected samples.
346346

trinity/buffer/task_scheduler.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
"""The taskset scheduler."""
33

44
from collections import Counter
5+
from copy import deepcopy
56
from typing import Dict, List
67

78
import numpy as np
89

910
from trinity.buffer.buffer import get_buffer_reader
10-
from trinity.buffer.selector import SELECTORS
1111
from trinity.common.config import Config
1212
from trinity.common.constants import SELECTOR_METRIC
1313
from trinity.utils.annotations import Experimental
@@ -47,7 +47,7 @@ def state_dict(self) -> List[Dict]:
4747
"""
4848
raise NotImplementedError
4949

50-
def update(self, pipeline_metrics: Dict) -> None:
50+
def feedback(self, pipeline_metrics: Dict) -> None:
5151
"""Update selectors using feedback from the training pipeline."""
5252
raise NotImplementedError
5353

@@ -68,16 +68,18 @@ def __init__(self, explorer_state: Dict, config: Config):
6868
index = self.explorer_state.get("taskset_states", [{"current_index": 0}])[0].get(
6969
"current_index", 0
7070
)
71-
self.config.buffer.explorer_input.tasksets[0].index = index
72-
self.reader = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
71+
taskset_config = deepcopy(self.config.buffer.explorer_input.tasksets[0])
72+
taskset_config.index = index
73+
taskset_config.task_selector = None # disable selection
74+
self.reader = get_buffer_reader(taskset_config)
7375

7476
async def read_async(self) -> List:
7577
return await self.reader.read_async()
7678

7779
def state_dict(self) -> List[Dict]:
7880
return [self.reader.state_dict()]
7981

80-
def update(self, pipeline_metrics: Dict) -> None:
82+
def feedback(self, pipeline_metrics: Dict) -> None:
8183
# do nothing here
8284
return
8385

@@ -127,7 +129,6 @@ def __init__(self, explorer_state: Dict, config: Config):
127129
"taskset_states", [{"current_index": 0}] * len(taskset_configs)
128130
)
129131
self.tasksets = []
130-
self.selectors = []
131132
for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
132133
assert not taskset_config.is_eval # assume drop last
133134
taskset = get_buffer_reader(taskset_config)
@@ -136,15 +137,8 @@ def __init__(self, explorer_state: Dict, config: Config):
136137
f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'."
137138
f"Currently, only 'FileReader' is supported by TasksetScheduler."
138139
)
139-
140-
# Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
141-
selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
142-
taskset.reader.dataset, taskset_config.task_selector
143-
)
144-
selector.load_state_dict(taskset_state) # Restore any prior state
145-
140+
taskset.load_state_dict(taskset_state) # Restore any prior state
146141
self.tasksets.append(taskset)
147-
self.selectors.append(selector)
148142

149143
# Each explorer step calls read_async once → track step globally
150144
self.step = explorer_state.get("latest_iteration", 0)
@@ -224,8 +218,7 @@ async def read_async(self) -> List:
224218
counter = Counter(taskset_ids)
225219
batch = []
226220
for taskset_id, count in counter.items():
227-
indices = self.selectors[taskset_id].get_indices(batch_size=count)
228-
tasks = await self.tasksets[taskset_id].read_with_indices_async(indices)
221+
tasks = await self.tasksets[taskset_id].read_async(batch_size=count)
229222
# Annotate each task with its origin
230223
for task in tasks:
231224
task.index["taskset_id"] = taskset_id
@@ -239,13 +232,13 @@ def state_dict(self) -> List[Dict]:
239232
Save persistent state for checkpointing.
240233
241234
Returns:
242-
List[Dict]: State dicts for all selectors (one per taskset)
235+
List[Dict]: State dicts for all tasksets
243236
"""
244-
return [selector.state_dict() for selector in self.selectors]
237+
return [taskset.state_dict() for taskset in self.tasksets]
245238

246-
def update(self, pipeline_metrics: Dict) -> None:
239+
def feedback(self, pipeline_metrics: Dict) -> None:
247240
"""
248-
Update selectors using feedback from the training pipeline.
241+
Update selectors in tasksets using feedback from the training pipeline.
249242
250243
Expected format:
251244
pipeline_metrics = {
@@ -265,5 +258,5 @@ def update(self, pipeline_metrics: Dict) -> None:
265258
return
266259
selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {})
267260
for taskset_id, taskset_kwargs in selector_metric.items():
268-
selector = self.selectors[taskset_id]
269-
selector.update(**taskset_kwargs)
261+
taskset = self.tasksets[taskset_id]
262+
taskset.feedback(**taskset_kwargs)

trinity/common/models/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,11 @@ def model_name(self) -> Optional[str]:
454454
"""Get the name of the model."""
455455
return self._model_name
456456

457+
@property
458+
def model_config(self) -> InferenceModelConfig:
459+
"""Get the model config."""
460+
return self.config
461+
457462
@property
458463
def generate_kwargs(self) -> Dict[str, Any]:
459464
"""Get the generation kwargs for openai client."""

trinity/explorer/explorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None:
411411
batch_id=step, min_num=self.min_wait_num
412412
)
413413
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
414-
self.taskset.update(pipeline_metrics)
414+
self.taskset.feedback(pipeline_metrics)
415415
metric.update(pipeline_metrics)
416416
if statuses:
417417
metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout"))

0 commit comments

Comments
 (0)