
Commit 292795a

add docs
1 parent 2aa5f37 commit 292795a

8 files changed: +74 additions, −43 deletions

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 7 additions & 2 deletions

@@ -313,7 +313,7 @@ Controls the rollout models and workflow execution.
 ```yaml
 explorer:
   name: explorer
-  runner_num: 32
+  runner_per_model: 8
   max_timeout: 900
   max_retry_times: 2
   env_vars: {}
@@ -324,17 +324,22 @@ explorer:
   auxiliary_models:
     - model_path: /PATH/TO/MODEL
       tensor_parallel_size: 1
+  eval_interval: 100
+  eval_on_startup: True
 ```

 - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
-- `runner_num`: Number of parallel workflow runners.
+- `runner_per_model`: Number of parallel workflow runners per rollout model.
 - `max_timeout`: Maximum time (in seconds) for a workflow to complete.
 - `max_retry_times`: Maximum number of retries for a workflow.
 - `env_vars`: Environment variables to be set for every workflow runner.
 - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
 - `rollout_model.engine_num`: Number of inference engines.
 - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
 - `auxiliary_models`: Additional models used for custom workflows.
+- `eval_interval`: Interval (in steps) for evaluating the model.
+- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, the evaluation runs at step 0 with the original model, so it is not triggered when restarting.
+- `runner_num`: (*Deprecated*) Number of parallel workflow runners.

 ---

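For orientation, the sketch below shows how `runner_per_model` would translate into overall rollout concurrency if the total number of workflow runners equals `engine_num * runner_per_model`; this relation is implied by the option name but not spelled out in the commit, and the dataclasses are simplified, hypothetical stand-ins rather than the real Trinity config classes.

```python
from dataclasses import dataclass


@dataclass
class RolloutModelSketch:
    """Hypothetical, trimmed-down stand-in for the rollout_model section."""

    engine_type: str = "vllm_async"
    engine_num: int = 2
    tensor_parallel_size: int = 1


@dataclass
class ExplorerSketch:
    """Hypothetical, trimmed-down stand-in for the explorer section."""

    name: str = "explorer"
    runner_per_model: int = 8
    max_timeout: int = 900
    max_retry_times: int = 2


def total_runners(explorer: ExplorerSketch, rollout: RolloutModelSketch) -> int:
    # Assumed relation: each rollout engine gets `runner_per_model` runners.
    return rollout.engine_num * explorer.runner_per_model


# With runner_per_model: 8 and engine_num: 2, this yields 16 concurrent runners.
print(total_runners(ExplorerSketch(), RolloutModelSketch()))
```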
tests/template/config.yaml

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ buffer:
   default_reward_fn_type: ''
 explorer:
   eval_interval: 100
-  runner_num: 4
+  runner_per_model: 8
   rollout_model:
     engine_type: vllm_async
     engine_num: 2

tests/trainer/trainer_test.py

Lines changed: 11 additions & 10 deletions

@@ -60,7 +60,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("copy_countdown", "test")
         )
-        self.config.trainer.save_interval = 4
+        self.config.trainer.save_interval = 6
         self.config.check_and_update()
         self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
         self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
@@ -84,24 +84,25 @@ def test_trainer(self):
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
         ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
-        checkpoint_step_4, _ = get_checkpoint_dir_with_step_num(
+        checkpoint_step_6, _ = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=4,
+            step_num=6,
         )
-        checkpoint_step_8, _ = get_checkpoint_dir_with_step_num(
+        # check that the latest checkpoint is saved
+        checkpoint_step_8, step_num = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.checkpoint_job_dir,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=8,
         )
-        self.assertTrue(os.path.exists(checkpoint_step_4))
-        self.assertTrue(os.path.exists(checkpoint_step_8))
+        self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_6, "actor"))) > 0)
+        self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
+        self.assertEqual(step_num, 8)
         # TODO: Reinit will fail when using v1 engine, find a way to fix it
         ray.init(ignore_reinit_error=True)
         # test bench mode
         self.config.mode = "bench"
         self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
-        self.config.explorer.eval_on_latest_checkpoint = False
+        self.config.explorer.bench_on_latest_checkpoint = False
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
@@ -116,7 +117,8 @@ def test_trainer(self):

     def tearDown(self):
         # remove dir only when the test passed
-        shutil.rmtree(self.config.checkpoint_job_dir)
+        # shutil.rmtree(self.config.checkpoint_job_dir)
+        pass


 class TestStepAheadAsyncRL(BaseTrainerCase):
@@ -328,7 +330,6 @@ def test_fully_async_mode(self):
         config.cluster.node_num = 1
         explorer1_config.explorer.rollout_model.engine_num = 1
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
-        explorer1_config.explorer.runner_num = 4
         explorer1_config.buffer.explorer_output = StorageConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,

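The updated test calls `get_checkpoint_dir_with_step_num` without `step_num` and expects both the latest checkpoint directory and its step number back. That helper is not part of this diff; the stand-in below is only a rough illustration of that behavior, and the `global_step_<n>` directory layout is an assumption, not the project's actual naming scheme.

```python
import os
import re
from typing import Optional, Tuple


def find_checkpoint_dir(checkpoint_root_path: str, step_num: Optional[int] = None) -> Tuple[str, int]:
    """Illustrative stand-in: resolve a checkpoint dir by step, or the latest one."""
    steps = {}
    for entry in os.listdir(checkpoint_root_path):
        match = re.fullmatch(r"global_step_(\d+)", entry)  # assumed directory layout
        if match:
            steps[int(match.group(1))] = os.path.join(checkpoint_root_path, entry)
    if not steps:
        raise FileNotFoundError(f"no checkpoints under {checkpoint_root_path}")
    chosen = step_num if step_num is not None else max(steps)
    return steps[chosen], chosen
```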
trinity/common/config.py

Lines changed: 8 additions & 5 deletions

@@ -301,12 +301,12 @@ class ExplorerConfig:
     name: str = EXPLORER_NAME
     # for workflow runner
     # number of workflow runners.
-    # For sync engine (vllm), it should be equal to `engine_num`.
-    # For async engine (vllm_async), it can be larger than `engine_num`, e.g. 16 * `engine_num`
-    runner_num: int = 1
+    # For sync engine (vllm), it should be `1`.
+    # For async engine (vllm_async), it could be a large number.
+    runner_per_model: int = 8  # number of runners per rollout model
     max_timeout: int = 900  # wait each task for 15 minutes
     max_retry_times: int = 2  # retry each task for 2 times if it fails or timeout
-    runner_per_model: int = 8
+    runner_num: Optional[int] = None  # deprecated, use `runner_per_model` instead

     # for inference models
     # for rollout model
@@ -316,7 +316,10 @@ class ExplorerConfig:

     # for evaluation
     eval_interval: int = 100
-    eval_on_latest_checkpoint: bool = False
+    eval_on_startup: bool = True  # evaluate at step 0
+
+    # for benchmark
+    bench_on_latest_checkpoint: bool = False  # only benchmark the latest checkpoint


 @dataclass

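The dataclass now keeps `runner_num` only as a deprecated `Optional[int]`; how the old value is migrated is not shown in this commit. A purely illustrative `__post_init__` hook (the warning text and behavior are assumptions) might look like:

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExplorerConfigSketch:
    """Trimmed-down stand-in showing only the runner-related fields."""

    runner_per_model: int = 8
    runner_num: Optional[int] = None  # deprecated

    def __post_init__(self) -> None:
        # Assumed behavior: warn when a config still sets the deprecated field.
        if self.runner_num is not None:
            warnings.warn(
                "`runner_num` is deprecated; configure `runner_per_model` instead.",
                DeprecationWarning,
            )
```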
trinity/explorer/explorer.py

Lines changed: 9 additions & 7 deletions

@@ -39,7 +39,7 @@ def __init__(self, config: Config):
         self.cache = CacheManager(config)
         explorer_meta = self.cache.load_explorer()
         self.explore_step_num = explorer_meta.get("latest_iteration", 0)
-        self.last_sync_step = self.explore_step_num
+        self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1
         self.config = config
         self.algorithm_manager = AlgorithmManager(config)
         self.models, self.auxiliary_models = create_inference_models(config)
@@ -169,6 +169,8 @@ async def prepare(self) -> None:
             asyncio.create_task(self.setup_weight_sync_group(master_address, master_port))
         )
         asyncio.gather(*futures, return_exceptions=True)
+        if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
+            self.eval()

     async def get_weight(self, name: str) -> torch.Tensor:
         """Get the weight of the loaded model (For checkpoint weights update)."""
@@ -177,21 +179,21 @@ async def get_weight(self, name: str) -> torch.Tensor:
     async def explore(self) -> str:
         """
         The timeline of the exploration process:
-        explorer | <--------------------------------- one period -------------------------------------> |
-                 | <------------------------------ eval -------------------------------> | <-- sync --> |
-                 | <---------------- step_1 --------------> |                                           |
+                 | <--------------------------------- one period -------------------------------------> |
+        explorer | <---------------- step_1 --------------> |                                           |
                  |           | <---------------- step_2 --------------> |                               |
                  |                                  ...                                                 |
                  |                          | <---------------- step_n ---------------> |               |
                  |                          | <---------------------- eval --------------------> | <-- sync --> |
-        trainer  |--------------------------------------------------------------------------------------|
-                 | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |
+                 |--------------------------------------------------------------------------------------|
+        trainer  | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |
         """
         while True:
             try:
                 self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
                 explore_contionue = await self.explore_step()
                 if not explore_contionue:
+                    # TODO: support eval on the last checkpoint
                     break
                 if self.need_eval():
                     self.eval()
@@ -253,7 +255,7 @@ def eval(self):
     async def benchmark(self) -> bool:
         """Benchmark the model checkpoints."""
         # benchmark on the latest checkpoint
-        if self.config.explorer.eval_on_latest_checkpoint:
+        if self.config.explorer.bench_on_latest_checkpoint:
            self.explore_step_num = await self._checkpoint_weights_update()
            self.eval()
            await self._log_eval_metrics()

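`prepare` now triggers an evaluation when `eval_on_startup` is set and the explorer starts from step 0, while periodic evaluation still goes through `need_eval`, which is not shown in this diff. The sketch below mirrors that gating; the modulo-based `need_eval` rule is an assumption.

```python
class EvalGatingSketch:
    """Illustrative only: mirrors the eval gating visible in the diff above."""

    def __init__(self, eval_interval: int = 100, eval_on_startup: bool = True):
        self.eval_interval = eval_interval
        self.eval_on_startup = eval_on_startup
        self.explore_step_num = 0  # restored from the explorer cache on restart

    def should_eval_on_startup(self) -> bool:
        # Step 0 means the original model; a restarted run resumes at > 0
        # and therefore skips the startup evaluation.
        return self.eval_on_startup and self.explore_step_num == 0

    def need_eval(self) -> bool:
        # Assumed periodic rule: evaluate every `eval_interval` explore steps.
        return self.explore_step_num > 0 and self.explore_step_num % self.eval_interval == 0


gate = EvalGatingSketch()
assert gate.should_eval_on_startup()
gate.explore_step_num = 100
assert gate.need_eval()
```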
trinity/explorer/scheduler.py

Lines changed: 28 additions & 10 deletions

@@ -1,6 +1,7 @@
 """Scheduler for rollout tasks."""

 import asyncio
+import re
 import time
 import traceback
 from collections import defaultdict, deque
@@ -89,6 +90,20 @@ def restart_runner(self):
         pass


+def sort_batch_id(batch_id: Union[int, str]):
+    """Priority key of a batch_id."""
+    # TODO: avoid sorting the batch_id every time
+    if isinstance(batch_id, int):
+        return (batch_id, 0)
+    else:
+        match = re.match(r"^(\d+)", batch_id)
+        if match:
+            num = int(match.group(1))
+            return (num, 1)
+        else:
+            return (float("inf"), 1)
+
+
 class Scheduler:
     """Scheduler for rollout tasks."""

@@ -112,9 +127,14 @@ def __init__(
         self.idle_runners = set()  # runner_id
         self.busy_runners = dict()  # runner_id -> (task, batch_id)

-        self.pending_tasks: Dict[str, deque] = defaultdict(deque)  # batch_id -> tasks
-        self.running_tasks: Dict[str, set[asyncio.Future]] = defaultdict(set)  # batch_id -> futures
-        self.completed_tasks: Dict[str, deque[Status]] = defaultdict(deque)  # batch_id -> results
+        self.pending_tasks_heap = []
+        self.pending_tasks: Dict[Union[int, str], deque] = defaultdict(deque)  # batch_id -> tasks
+        self.running_tasks: Dict[Union[int, str], set[asyncio.Future]] = defaultdict(
+            set
+        )  # batch_id -> futures
+        self.completed_tasks: Dict[Union[int, str], deque[Status]] = defaultdict(
+            deque
+        )  # batch_id -> results

         self.scheduler_task: Optional[asyncio.Task] = None
         self.running = False
@@ -168,7 +188,7 @@ async def _schedule_pending_tasks(self) -> None:
             return

         # TODO: Support more advanced scheduling strategies
-        for batch_id in sorted(self.pending_tasks.keys()):
+        for batch_id in sorted(self.pending_tasks.keys(), key=sort_batch_id):
             task_queue = self.pending_tasks[batch_id]

             while task_queue and self.idle_runners:
@@ -205,7 +225,7 @@ async def _check_completed_tasks(self) -> None:
             if not futures:
                 del self.running_tasks[batch_id]

-    def _clear_timeout_tasks(self, batch_id: str) -> None:
+    def _clear_timeout_tasks(self, batch_id: Union[int, str]) -> None:
         if batch_id in self.pending_tasks:
             self.logger.info(f"Clear timeout pending tasks at batch_id {batch_id}.")
             del self.pending_tasks[batch_id]
@@ -252,11 +272,11 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None:

         Args:
             tasks (`List[Task]`): The tasks to schedule.
-            batch_id (`Union[int, str]`): The id of provided tasks.
+            batch_id (`Union[int, str]`): The id of the provided tasks. It should be an integer or a string
+                starting with an integer (e.g., 123, "123/my_task").
         """
         if not tasks:
             return
-        batch_id = str(batch_id)
         for task in tasks:
             self.pending_tasks[batch_id].appendleft(task)
@@ -276,7 +296,6 @@ async def get_results(
             clear_timeout_tasks (`bool`): Whether to clear timeout tasks.
         """
         timeout = timeout or self.timeout
-        batch_id = str(batch_id)
         start_time = time.time()
         if min_num is None:
             min_num = 0
@@ -320,7 +339,6 @@ async def get_results(
         return results

     def has_step(self, batch_id: Union[int, str]) -> bool:
-        batch_id = str(batch_id)
         return (
             batch_id in self.completed_tasks
             or batch_id in self.pending_tasks
@@ -353,8 +371,8 @@ async def wait_all(
             running_count = sum(len(futures) for futures in self.running_tasks.values())

             self.logger.debug(f"Pending tasks: {pending_count}, Running tasks: {running_count}")
-
             await asyncio.sleep(0.1)
+
             pending_count = sum(len(tasks) for tasks in self.pending_tasks.values())
             running_count = sum(len(futures) for futures in self.running_tasks.values())
         error_msg = f"Timeout after {timeout} seconds. Still have {pending_count} pending tasks and {running_count} running tasks."

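Because `schedule`, `get_results`, and `has_step` no longer coerce `batch_id` to `str`, the new `sort_batch_id` key is what keeps mixed integer and string ids in step order. A small standalone usage example (the function body is a lightly condensed copy of the one added above):

```python
import re
from typing import Union


def sort_batch_id(batch_id: Union[int, str]):
    """Integer ids sort before string ids that share the same numeric prefix."""
    if isinstance(batch_id, int):
        return (batch_id, 0)
    match = re.match(r"^(\d+)", batch_id)
    if match:
        return (int(match.group(1)), 1)
    return (float("inf"), 1)


batch_ids = ["2/eval", 10, "10/bench", 2, "no_prefix"]
print(sorted(batch_ids, key=sort_batch_id))
# [2, '2/eval', 10, '10/bench', 'no_prefix']
```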
trinity/manager/config_manager.py

Lines changed: 6 additions & 4 deletions

@@ -199,9 +199,11 @@ def _expert_buffer_part(self):
     def _expert_explorer_part(self):
         self.get_configs("sync_method", "sync_interval", "sync_timeout")

-        self.get_configs("runner_num", "max_timeout", "explorer_max_retry_times", "eval_interval")
+        self.get_configs(
+            "runner_per_model", "max_timeout", "explorer_max_retry_times", "eval_interval"
+        )

-        self.get_configs("eval_on_latest_checkpoint")
+        self.get_configs("bench_on_latest_checkpoint")

         with st.expander("Rollout Model Config", expanded=True):
             self.get_configs("engine_type", "engine_num", "tensor_parallel_size")
@@ -571,7 +573,7 @@ def _gen_buffer_config(self):

     def _gen_explorer_config(self):
         explorer_config = {
-            "runner_num": st.session_state["runner_num"],
+            "runner_per_model": st.session_state["runner_per_model"],
             "max_timeout": st.session_state["max_timeout"],
             "max_retry_times": st.session_state["explorer_max_retry_times"],
             "rollout_model": {
@@ -584,7 +586,7 @@ def _gen_explorer_config(self):
             },
             "auxiliary_models": [],
             "eval_interval": st.session_state["eval_interval"],
-            "eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"],
+            "bench_on_latest_checkpoint": st.session_state["bench_on_latest_checkpoint"],
         }
         for i in range(st.session_state["_auxiliary_models_num"]):
             auxiliary_model_config = {

trinity/manager/config_registry/explorer_config_manager.py

Lines changed: 4 additions & 4 deletions

@@ -9,9 +9,9 @@ def explorer_visible() -> bool:
     return st.session_state["mode"] == "both"


-@CONFIG_GENERATORS.register_config(default_value=32, visible=explorer_visible)
-def set_runner_num(**kwargs):
-    st.number_input("Runner Num", min_value=1, **kwargs)
+@CONFIG_GENERATORS.register_config(default_value=8, visible=explorer_visible)
+def set_runner_per_model(**kwargs):
+    st.number_input("Runner per Model", min_value=1, **kwargs)


 @CONFIG_GENERATORS.register_config(default_value=900, visible=explorer_visible)
@@ -30,7 +30,7 @@ def set_eval_interval(**kwargs):


 @CONFIG_GENERATORS.register_config(default_value=True, visible=explorer_visible)
-def set_eval_on_latest_checkpoint(**kwargs):
+def set_bench_on_latest_checkpoint(**kwargs):
     st.checkbox("Eval on Latest Checkpoint", **kwargs)