
Commit cb3f16b

add taskset scheduler
1 parent 3d12bd9 commit cb3f16b

File tree: 5 files changed, +156 −72 lines


trinity/buffer/reader/file_reader.py
Lines changed: 34 additions & 22 deletions

@@ -3,7 +3,7 @@
 from typing import List, Optional

 import datasets
-from datasets import Dataset, load_dataset
+from datasets import Dataset, IterableDataset, load_dataset

 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.schema.formatter import FORMATTER
@@ -32,18 +32,24 @@ def __init__(
         drop_last: bool = True,
         total_steps: Optional[int] = None,
         enable_progress_bar: Optional[bool] = True,
+        shuffle: bool = False,
+        base_seed: Optional[int] = 42,
     ):
-        self.dataset = dataset
         self.dataset_size = len(dataset)
         self.name = name
         self.current_batch_size = None
         self.drop_last = drop_last
+        self.shuffle = shuffle
+        self.base_seed = base_seed

         self.current_offset = offset
-        self.iter = iter(self.dataset)
-
-        for _ in range(self.current_offset % self.dataset_size):
-            next(self.iter)
+        if self.shuffle:
+            assert not isinstance(
+                dataset, IterableDataset
+            ), "Shuffle is not supported for IterableDataset"
+            self.dataset = dataset.shuffle(seed=self.current_seed)
+        else:
+            self.dataset = dataset

         # convert epochs/steps to sample number
         if total_steps:
@@ -63,29 +69,29 @@ def __init__(

         self.progress_bar.update(self.current_offset)

+    def current_seed(self):
+        return self.base_seed + self.current_offset // self.dataset_size
+
     def read_batch(self, batch_size: int) -> List:
         if self.current_offset >= self.total_samples:
             self.progress_bar.close()
             raise StopIteration
         batch = []

         while len(batch) < batch_size:
-            try:
-                item = next(self.iter)
-                batch.append(item)
-                self.current_offset += 1
-            except StopIteration:
-                if self.current_offset >= self.total_samples:
-                    # No more data to read
-                    if not self.drop_last and len(batch) > 0:
-                        # return last batch
-                        self.progress_bar.update(len(batch))
-                        return batch
-                    else:
-                        self.progress_bar.close()
-                        raise StopIteration
-                # Step to the next epoch
-                self.iter = iter(self.dataset)
+            batch.append(self.dataset[self.current_offset % self.dataset_size])
+            self.current_offset += 1
+            if self.shuffle and self.current_offset % self.dataset_size == 0:
+                self.dataset = self.dataset.shuffle(seed=self.current_seed)
+            if self.current_offset >= self.total_samples:
+                # No more data to read
+                if not self.drop_last and len(batch) > 0:
+                    # return last batch
+                    self.progress_bar.update(len(batch))
+                    return batch
+                else:
+                    self.progress_bar.close()
+                    raise StopIteration
         self.progress_bar.update(batch_size)
         return batch

@@ -144,9 +150,15 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
             drop_last=not self.meta.is_eval,
             total_steps=meta.total_steps,
             enable_progress_bar=meta.enable_progress_bar,
+            shuffle=meta.shuffle,
+            base_seed=meta.seed,
         )
         self.formatter = FORMATTER.get("task")(meta)

+    @property
+    def index(self) -> int:
+        return self.dataset.current_offset
+
     def read(self, batch_size: Optional[int] = None) -> List:
         batch_size = batch_size or self.read_batch_size
         tasks = []
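
Note on the reader change: instead of walking a Python iterator, the reader now indexes the dataset by `current_offset % dataset_size` and, when `shuffle` is enabled, re-shuffles with a seed derived from `base_seed` plus the number of completed epochs, so the sample order is deterministic and a restart from a saved offset reproduces it. Below is a minimal, self-contained sketch of that resumable epoch-shuffle pattern; the class and helper names are illustrative only, not Trinity's API.

import random
from typing import List


class ShuffledIndexReader:
    """Illustrative stand-in: deterministic, resumable epoch shuffling over a list."""

    def __init__(self, data: List, offset: int = 0, base_seed: int = 42):
        self.data = list(data)
        self.size = len(data)
        self.offset = offset          # global sample counter, survives restarts
        self.base_seed = base_seed
        self._reshuffle()             # order for the current epoch

    @property
    def current_seed(self) -> int:
        # one seed per epoch: epoch index = completed passes over the data
        return self.base_seed + self.offset // self.size

    def _reshuffle(self) -> None:
        self.order = list(range(self.size))
        random.Random(self.current_seed).shuffle(self.order)

    def read_batch(self, batch_size: int) -> List:
        batch = []
        while len(batch) < batch_size:
            batch.append(self.data[self.order[self.offset % self.size]])
            self.offset += 1
            if self.offset % self.size == 0:   # epoch boundary: new deterministic order
                self._reshuffle()
        return batch


reader = ShuffledIndexReader(list(range(10)), offset=0)
reader.read_batch(4)
# Restarting from offset 4 reproduces the same continuation:
resumed = ShuffledIndexReader(list(range(10)), offset=4)
assert resumed.read_batch(4) == reader.read_batch(4)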

trinity/buffer/task_scheduler.py
Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+"""The taskset scheduler."""
+
+from collections import deque
+from typing import Dict, List, Optional
+from trinity.buffer.buffer import get_buffer_reader
+from trinity.common.config import Config
+
+
+class TasksetScheduler:
+    def __init__(self, explorer_state, config: Config):
+        if 'latest_task_index' in explorer_state:
+            assert len(config.buffer.explorer_input.taskset) == 1  # old format
+            explorer_state['taskset'] = [
+                {
+                    "index": explorer_state['latest_task_index'],
+                }
+            ]
+
+        tasksets_config = config.buffer.explorer_input.tasksets
+
+        tasksets_state = explorer_state.get('taskset', [{"index": 0}] * len(tasksets_config))
+        self.tasksets = []
+        for taskset_config, taskset_state in zip(tasksets_config, tasksets_state):
+            taskset_config.index = taskset_state["index"]
+            assert not taskset_config.is_eval
+            self.tasksets.append(get_buffer_reader(taskset_config, config.buffer))
+        self.tasksets_queue = deque()
+        for taskset in self.tasksets:
+            self.tasksets_queue.append(taskset)
+
+    def read(self, batch_size: Optional[int] = None) -> List:
+        batch = []
+        for _ in range(len(self.tasksets_queue)):
+            taskset = self.tasksets_queue.popleft()
+            try:
+                batch = taskset.read(batch_size)
+                assert len(batch) == batch_size
+                self.tasksets_queue.append(taskset)
+                break
+            except StopIteration:
+                pass
+        if len(batch) == 0:
+            raise StopIteration
+        return batch
+
+    async def read_async(self, batch_size: Optional[int] = None) -> List:
+        try:
+            return self.read(batch_size)
+        except StopIteration as e:
+            raise StopAsyncIteration from e
+
+    def save_state(self) -> Dict:
+        return [
+            {
+                "index": taskset.index,
+            }
+            for taskset in self.tasksets
+        ]
+
+    def update(self, experiences, explore_metric, eval_metric) -> None:
+        pass
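
The scheduler rotates a deque of taskset readers: each `read` pops readers until one yields a full batch, re-appends that reader for the next round, and raises `StopIteration` only when nothing produced a batch. A toy sketch of the same rotation over plain iterators (the sources and numbers are made up for illustration; Trinity's real readers expose `read(batch_size)` and signal exhaustion with `StopIteration`):

from collections import deque
from typing import Iterator, List


class RoundRobinReader:
    """Toy stand-in for the scheduler's rotation: try each source in turn,
    keep the one that still yields a full batch, drop exhausted ones."""

    def __init__(self, sources: List[Iterator]):
        self.queue = deque(sources)

    def read(self, batch_size: int) -> List:
        batch: List = []
        for _ in range(len(self.queue)):
            source = self.queue.popleft()               # rotate to the next source
            batch = [x for _, x in zip(range(batch_size), source)]
            if len(batch) == batch_size:
                self.queue.append(source)               # still alive: schedule it again
                return batch
            # exhausted or short: leave it out of the queue and try the next one
        raise StopIteration                             # every source is exhausted


a = iter(range(3))            # short source, dropped after the first call
b = iter(range(100, 110))
rr = RoundRobinReader([a, b])
print(rr.read(4))             # falls through a (too short) and returns [100, 101, 102, 103]
print(rr.read(4))             # [104, 105, 106, 107]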

trinity/common/config.py
Lines changed: 46 additions & 37 deletions

@@ -7,7 +7,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from omegaconf import OmegaConf

@@ -108,6 +108,10 @@ class StorageConfig:
     path: Optional[str] = None
     repeat_times: Optional[int] = None

+    # For shuffle
+    shuffle: bool = False
+    seed: int = 42
+
     # For continuing training
     index: int = 0

@@ -369,7 +373,8 @@ class ClusterConfig:
 class ExplorerInput:
     """Config for explorer input."""

-    taskset: StorageConfig = field(default_factory=StorageConfig)
+    taskset: Optional[StorageConfig] = None
+    tasksets: List[StorageConfig] = field(default_factory=list)
     eval_tasksets: List[StorageConfig] = field(default_factory=list)
     # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
     default_workflow_type: Optional[str] = None
@@ -630,40 +635,44 @@ def _check_buffer(self) -> None: # noqa: C901
         trainer_input = self.buffer.trainer_input
         experience_buffer = trainer_input.experience_buffer
         explorer_input = self.buffer.explorer_input
-        taskset = explorer_input.taskset

-        if self.mode != "train" and not taskset.path:
-            raise ValueError(
-                "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
-            )
-        if not taskset.name:
-            taskset.name = "taskset"
-        if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times:
-            taskset.repeat_times = self.algorithm.repeat_times
-            logger.info(
-                "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
-                f" (={self.algorithm.repeat_times})."
+        if len(explorer_input.tasksets) == 0 and explorer_input.taskset:
+            explorer_input.tasksets.append(explorer_input.taskset)
+        tasksets = explorer_input.tasksets
+
+        for taskset in tasksets:
+            if self.mode != "train" and not taskset.path:
+                raise ValueError(
+                    "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
+                )
+            if not taskset.name:
+                taskset.name = "taskset"
+            if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times:
+                taskset.repeat_times = self.algorithm.repeat_times
+                logger.info(
+                    "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
+                    f" (={self.algorithm.repeat_times})."
+                )
+            if self.mode == "train":
+                assert (
+                    experience_buffer is not None
+                ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
+                experience_buffer.total_epochs = self.buffer.total_epochs
+                experience_buffer.total_steps = self.buffer.total_steps
+            else:
+                taskset.is_eval = False
+                taskset.total_epochs = self.buffer.total_epochs
+                taskset.total_steps = self.buffer.total_steps
+
+            set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
+            set_if_none(
+                taskset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type
             )
-        if self.mode == "train":
-            assert (
-                experience_buffer is not None
-            ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
-            experience_buffer.total_epochs = self.buffer.total_epochs
-            experience_buffer.total_steps = self.buffer.total_steps
-        else:
-            taskset.is_eval = False
-            taskset.total_epochs = self.buffer.total_epochs
-            taskset.total_steps = self.buffer.total_steps
-
-        set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
-        set_if_none(
-            taskset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type
-        )
-        set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
-        set_if_none(taskset.format, "system_prompt", explorer_input.system_prompt)
-        set_if_none(taskset.format, "reply_prefix", explorer_input.reply_prefix)
-        set_if_none(taskset, "ray_namespace", self.ray_namespace)
-        set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
+            set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
+            set_if_none(taskset.format, "system_prompt", explorer_input.system_prompt)
+            set_if_none(taskset.format, "reply_prefix", explorer_input.reply_prefix)
+            set_if_none(taskset, "ray_namespace", self.ray_namespace)
+            set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)

         remained_tasksets = []
         for idx, dataset in enumerate(explorer_input.eval_tasksets):
@@ -730,8 +739,8 @@ def _check_buffer(self) -> None: # noqa: C901
         task_pipeline = self.data_processor.task_pipeline
         if task_pipeline is not None:
             if task_pipeline.output is None:
-                if taskset.path is not None:
-                    task_pipeline.output = taskset
+                if tasksets[0].path is not None:
+                    task_pipeline.output = tasksets[0]
                 elif (
                     experience_buffer.schema_type in {"dpo", "sft"}
                     and experience_buffer.path is not None
@@ -740,7 +749,7 @@ def _check_buffer(self) -> None: # noqa: C901
                 else:
                     raise ValueError(
                         "`data_processor.task_pipeline.output` is required when both "
-                        "`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are "
+                        "`buffer.explorer_input.tasksets[0].path` and `buffer.trainer_input.experience_buffer.path` are "
                        "None"
                    )
            if task_pipeline.output.path and os.path.exists(task_pipeline.output.path):
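
Config-wise, the old single `buffer.explorer_input.taskset` keeps working: `_check_buffer` folds it into the new `tasksets` list when that list is empty, then applies the per-taskset defaults inside a loop. A trimmed sketch of that fold-in rule using stand-in dataclasses (the field set and paths here are illustrative only, not the full StorageConfig):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StorageConfig:            # trimmed stand-in for trinity.common.config.StorageConfig
    path: Optional[str] = None
    shuffle: bool = False
    seed: int = 42


@dataclass
class ExplorerInput:            # trimmed stand-in for the real ExplorerInput
    taskset: Optional[StorageConfig] = None
    tasksets: List[StorageConfig] = field(default_factory=list)


def normalize(explorer_input: ExplorerInput) -> List[StorageConfig]:
    # Same fold-in rule as _check_buffer: a legacy single `taskset` becomes tasksets[0].
    if not explorer_input.tasksets and explorer_input.taskset:
        explorer_input.tasksets.append(explorer_input.taskset)
    return explorer_input.tasksets


old_style = ExplorerInput(taskset=StorageConfig(path="data/gsm8k.jsonl", shuffle=True))
new_style = ExplorerInput(tasksets=[StorageConfig(path="a.jsonl"), StorageConfig(path="b.jsonl")])
assert len(normalize(old_style)) == 1
assert len(normalize(new_style)) == 2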

trinity/explorer/explorer.py
Lines changed: 12 additions & 11 deletions

@@ -15,6 +15,7 @@

 from trinity.buffer.buffer import get_buffer_reader
 from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
+from trinity.buffer.task_scheduler import TasksetScheduler
 from trinity.common.config import Config
 from trinity.common.constants import (
     ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
@@ -49,12 +50,7 @@ def __init__(self, config: Config):
         self.config = config
         self.models, self.auxiliary_models = create_inference_models(config)
         self.experience_pipeline = self._init_experience_pipeline()
-        self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0)
-        self.taskset = (
-            get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
-            if self.config.mode != "serve"
-            else None
-        )
+        self.taskset = TasksetScheduler(explorer_state, config)
         self.scheduler = None
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
@@ -324,7 +320,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
         # save explore checkpoint
         self.state.save_explorer(
             current_step=self.explore_step_num,
-            current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+            taskset_state=self.taskset.save_state(),
         )

     async def sync_weight(self) -> None:
@@ -335,19 +331,23 @@ async def sync_weight(self) -> None:
     async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
         for step in range(start_step, end_step + 1):
             self.logger.info(f"Log metrics of step {step}")
-            await self._finish_explore_step(step=step, model_version=model_version)
-            await self._finish_eval_step(step=step)
+            explore_metric, exps = await self._finish_explore_step(
+                step=step, model_version=model_version
+            )
+            eval_metric = await self._finish_eval_step(step=step)
+            self.taskset.update(exps, explore_metric, eval_metric)

-    async def _finish_explore_step(self, step: int, model_version: int) -> None:
+    async def _finish_explore_step(self, step: int, model_version: int):
         statuses, exps = await self.scheduler.get_results(batch_id=step)
         metric = {"rollout/model_version": model_version}
         pipeline_metrics = await self.experience_pipeline.process.remote(exps)
         metric.update(pipeline_metrics)
         if statuses:
             metric.update(gather_metrics([status.metric for status in statuses], "rollout"))
         self.monitor.log(metric, step=step)
+        return metric, exps

-    async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
+    async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval"):
         if not self.pending_eval_tasks:
             return
         step = step or self.explore_step_num
@@ -366,6 +366,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
         )
         metric[f"{prefix}/total_time"] = time.time() - st
         self.monitor.log(metric, step)
+        return metric

     async def shutdown(self) -> None:
         if self.scheduler:
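
In the explorer, `_finish_explore_step` and `_finish_eval_step` now return their metrics so `_finish_steps` can feed them, together with the collected experiences, back into the scheduler via `update` (currently a no-op). A stand-in sketch of that feedback wiring, with dummy coroutines in place of the real explorer methods (all names and values here are illustrative):

import asyncio
from typing import Dict, List, Optional, Tuple


class RecordingScheduler:
    """Stand-in scheduler: just records what the explorer feeds back each step."""

    def __init__(self):
        self.feedback = []

    def update(self, experiences: List, explore_metric: Dict, eval_metric: Optional[Dict]) -> None:
        self.feedback.append((len(experiences), explore_metric, eval_metric))


async def finish_explore_step(step: int, model_version: int) -> Tuple[Dict, List]:
    # Dummy stand-in: return (metric, experiences) like the refactored method.
    return {"rollout/model_version": model_version, "step": step}, [f"exp-{step}"]


async def finish_eval_step(step: int) -> Optional[Dict]:
    # Dummy stand-in: may return None when no eval tasks are pending.
    return None if step % 2 else {"eval/acc": 0.0}


async def finish_steps(scheduler: RecordingScheduler, start_step: int, end_step: int, model_version: int) -> None:
    # Mirrors the new _finish_steps wiring: metrics now flow back into the scheduler.
    for step in range(start_step, end_step + 1):
        explore_metric, exps = await finish_explore_step(step, model_version)
        eval_metric = await finish_eval_step(step)
        scheduler.update(exps, explore_metric, eval_metric)


scheduler = RecordingScheduler()
asyncio.run(finish_steps(scheduler, 1, 3, model_version=7))
print(scheduler.feedback)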

trinity/manager/state_manager.py
Lines changed: 2 additions & 2 deletions

@@ -49,13 +49,13 @@ def _check_config_consistency(self, config: Config) -> None:
     def save_explorer(
         self,
         current_task_index: int,
-        current_step: int,
+        taskset_state: dict,
     ) -> None:
         with open(self.explorer_state_path, "w", encoding="utf-8") as f:
             json.dump(
                 {
                     "latest_task_index": current_task_index,
-                    "latest_iteration": current_step,
+                    "taskset_state": taskset_state,
                 },
                 f,
                 indent=2,
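
The explorer state file now persists a per-taskset state list instead of a single iteration counter. A sketch of what the persisted JSON could look like under the new shape (the file name and all values are illustrative; the exact keys are whatever `save_explorer` writes):

import json

explorer_state = {
    "latest_task_index": 1024,                           # kept from the old format
    "taskset_state": [{"index": 512}, {"index": 768}],   # one entry per taskset, from save_state()
}

with open("explorer_state.json", "w", encoding="utf-8") as f:
    json.dump(explorer_state, f, indent=2)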
