Commit 99a772a

Unify async/sync RL (#91)
1 parent 7a1c526 commit 99a772a

13 files changed: +238 -243 lines changed


tests/template/verl_config.yaml

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,7 @@ actor_rollout_ref:
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size
     checkpoint:
-      contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
+      contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
     optim:
       lr: 1e-6
       lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
@@ -72,6 +72,8 @@ critic:
   shuffle: ${actor_rollout_ref.actor.shuffle}
   grad_clip: 1.0
   cliprange_value: 0.5
+  checkpoint:
+    contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
 
 trainer:
   balance_batch: True

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 33 additions & 5 deletions
@@ -12,26 +12,40 @@
 
 
 class SampleStrategy(ABC):
-    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None:
         self.pad_token_id = buffer_config.pad_token_id
         self.trainer_type = trainer_type
 
     @abstractmethod
     def sample(self, step: int) -> Tuple[Any, Dict, List]:
-        """Sample experiences from buffer.
+        """Sample data from buffer.
 
         Args:
             step (`int`): The step number of current step.
 
         Returns:
-            `Any`: The sampled experiences.
+            `Any`: The sampled data.
             `Dict`: Metrics for logging.
-            `List`: Representative experiences for logging.
+            `List`: Representative data for logging.
+        """
+
+    # Experimental API
+    @abstractmethod
+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        """Check the warmup state of the current step.
+
+        Args:
+            step (`int`): The step number of current step.
+
+        Returns:
+            `bool`: Current step is in warmup or not.
+            `bool`: Warmup is finished on this step or not.
         """
 
     @classmethod
+    @abstractmethod
     def default_args(cls) -> dict:
-        return {}
+        """Get the default arguments of the sample strategy."""
 
 
 @SAMPLE_STRATEGY.register_module("warmup")
@@ -70,6 +84,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
         else:
             raise NotImplementedError(f"backend {self.trainer_type} is not supported")
 
+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        return step <= self.sft_warmup_steps, step == self.sft_warmup_steps
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {}
+
 
 @SAMPLE_STRATEGY.register_module("default")
 class DefaultSampleStrategy(SampleStrategy):
@@ -93,6 +114,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
         else:
             raise NotImplementedError(f"backend {self.trainer_type} is not supported")
 
+    def warmup_state(self, step: int) -> Tuple[bool, bool]:
+        return False, False
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {}
+
 
 @SAMPLE_STRATEGY.register_module("dpo")
 class DPOSampleStrategy(WarmupSampleStrategy):
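
For orientation, a minimal self-contained sketch of how a caller might consume the new warmup_state contract. ToyWarmupStrategy, the loop, and the printed labels are invented for this sketch; the warmup_state logic and the sft_warmup_steps attribute mirror the WarmupSampleStrategy change above.

    from typing import Tuple


    class ToyWarmupStrategy:
        """Stand-in mimicking WarmupSampleStrategy.warmup_state from this diff."""

        def __init__(self, sft_warmup_steps: int) -> None:
            self.sft_warmup_steps = sft_warmup_steps

        def warmup_state(self, step: int) -> Tuple[bool, bool]:
            # Returns (currently in warmup, warmup finished on this exact step).
            return step <= self.sft_warmup_steps, step == self.sft_warmup_steps


    strategy = ToyWarmupStrategy(sft_warmup_steps=3)
    for step in range(1, 6):
        in_warmup, just_finished = strategy.warmup_state(step)
        mode = "SFT warmup" if in_warmup else "RL"
        print(f"step {step}: {mode} (warmup finished now: {just_finished})")

Steps 1 through 3 report SFT warmup, with the finished flag raised only at step 3; later steps report RL.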

trinity/cli/launcher.py

Lines changed: 32 additions & 58 deletions
@@ -2,6 +2,7 @@
 import argparse
 import os
 import sys
+import traceback
 from pathlib import Path
 from pprint import pprint
 
@@ -18,44 +19,41 @@
 
 def bench(config: Config) -> None:
     """Evaluate model."""
-    explorer = Explorer.remote(config)
+    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.benchmark.remote())
         logger.info("Benchmark finished.")
         ray.get(explorer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Benchmark failed: {e}")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Benchmark failed:\n{error_msg}")
 
 
 def explore(config: Config) -> None:
     """Run explorer."""
-    explorer = Explorer.remote(config)
     try:
+        explorer = ray.remote(Explorer).options(name="explorer").remote(config)
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
-        logger.info("Explore finished.")
         ray.get(explorer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Explore failed: {e}")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Explorer failed:\n{error_msg}")
 
 
 def train(config: Config) -> None:
     """Run trainer."""
-
-    trainer = Trainer.remote(config)
-    ray.get(trainer.prepare.remote())
-
     try:
+        trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+        ray.get(trainer.prepare.remote())
+        ray.get(trainer.sync_weight.remote())
         ray.get(trainer.train.remote())
-        logger.info("Train finished.")
         ray.get(trainer.shutdown.remote())
-    except Exception as e:
-        logger.error(f"Train failed {e}.")
-        raise e
+    except Exception:
+        error_msg = traceback.format_exc()
+        logger.error(f"Trainer failed:\n{error_msg}")
 
 
 def both(config: Config) -> None:
@@ -68,54 +66,30 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for different
     algorithms and tasks.
     """
-    explorer = Explorer.remote(config)
-    trainer = Trainer.remote(config)
+    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+    trainer = ray.remote(Trainer).options(name="trainer").remote(config)
     ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
-    logger.info("Setup explorer and trainer finished.")
     ray.get(
         [
             explorer.prepare.remote(),
             trainer.prepare.remote(),
         ]
     )
-    # sync weight before training start
-    ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-
-    while True:
-        try:
-            ref_explore = explorer.explore_one_period.remote()
-            ref_train = trainer.train_one_period.remote()
-            explore_continue, explore_step_num = ray.get(ref_explore)
-            train_continue, train_step_num = ray.get(ref_train)
-            if not explore_continue:
-                # If explore finished, the trainer may not have enough experiences to continue,
-                # which will cause the trainer be blocked. So we stop the training process
-                # immediately.
-                # TODO: use a more elegant way to stop the training process.
-                logger.info("Explorer finished, stopping...")
-                break
-            if not train_continue:
-                logger.info("Trainer finished, stopping...")
-                break
-            ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()])
-            logger.info("Model weight synchronized.")
-        except Exception as e:
-            logger.error(e)
-            logger.error("Training stopped due to exception.")
-            raise e
-        if explore_step_num % config.explorer.eval_interval == 0:
-            try:
-                ray.get(explorer.eval.remote())
-                logger.info("Evaluation finished.")
-            except Exception as e:
-                logger.error(e)
-                logger.error("Evaluation failed.")
-                raise e
-        ray.get(explorer.flush_log.remote(step=explore_step_num))
-        ray.get(trainer.flush_log.remote(step=train_step_num))
-
-    ray.get(explorer.shutdown.remote())
-    ray.get(trainer.shutdown.remote())
+    ray.get(
+        [
+            explorer.sync_weight.remote(),
+            trainer.sync_weight.remote(),
+        ]
+    )
+    _, _ = ray.wait(
+        [
+            explorer.explore.remote(),
+            trainer.train.remote(),
+        ],
+        num_returns=1,
+    )
+    explorer.shutdown.remote(),
+    trainer.shutdown.remote(),
 
 
 def activate_data_module(data_workflow_url: str, config_path: str):
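
The rewritten both() above no longer drives explore_one_period/train_one_period in a lock-step loop; it launches explore and train as two long-running remote calls and blocks on ray.wait(num_returns=1), returning as soon as either side finishes. A generic, self-contained sketch of that pattern, with toy Ray tasks (run_loop, the step counts, and the sleeps are invented) standing in for Trinity's Explorer and Trainer actors:

    import time

    import ray

    ray.init()


    @ray.remote
    def run_loop(name: str, steps: int) -> str:
        for _ in range(steps):
            time.sleep(0.1)  # stand-in for one explore/train step
        return name


    refs = [run_loop.remote("explorer", 3), run_loop.remote("trainer", 10)]
    done, pending = ray.wait(refs, num_returns=1)  # unblocks when the first task ends
    print("finished first:", ray.get(done[0]), "| still pending:", len(pending))
    ray.shutdown()

Waiting for whichever side finishes first appears to be part of the "Unify async/sync RL" intent: the same entry point works whether the explorer and trainer run in lock step or drift apart.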

trinity/common/config.py

Lines changed: 3 additions & 1 deletion
@@ -319,8 +319,10 @@ class SynchronizerConfig:
     sync_method: SyncMethod = SyncMethod.NCCL
     # sync weights every `sync_interval` steps
     sync_interval: int = 1
+    # allow explorer to run `sync_offset` steps before sync
+    sync_offset: int = 0
     # waiting for `sync_timeout` seconds before timeout in `nccl` method
-    sync_timeout: int = 1200
+    sync_timeout: int = 1800
     # wait for the lastest checkpoint to be ready # TODO: to be used
     wait_for_checkpoint: bool = False
 
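
A small sketch of one plausible reading of the new sync_offset knob (the exact gating rule lives in Trinity's synchronizer and is not shown in this diff): the explorer may run sync_offset steps ahead before weight synchronization kicks in, after which syncs happen every sync_interval steps. ToySyncConfig and should_sync are illustrative names only.

    from dataclasses import dataclass


    @dataclass
    class ToySyncConfig:
        sync_interval: int = 1
        sync_offset: int = 0
        sync_timeout: int = 1800  # new default in this commit (previously 1200)


    def should_sync(step: int, cfg: ToySyncConfig) -> bool:
        # Assumed rule: no sync while the explorer is still within its head start,
        # then sync on every sync_interval-th step after the offset.
        if step <= cfg.sync_offset:
            return False
        return (step - cfg.sync_offset) % cfg.sync_interval == 0


    cfg = ToySyncConfig(sync_interval=2, sync_offset=1)
    print([step for step in range(1, 8) if should_sync(step, cfg)])  # [3, 5, 7]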

trinity/common/models/utils.py

Lines changed: 0 additions & 1 deletion
@@ -156,7 +156,6 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None
                 iteration = f.read().strip()
             return os.path.join(checkpoint_path, f"global_step_{iteration}")
         else:
-            logger.error(f"No iteration file found in {checkpoint_path}")
             raise FileNotFoundError(f"No iteration file found in {checkpoint_path}")
     else:
         # load specific iteration checkpoint

trinity/common/models/vllm_async_model.py

Lines changed: 4 additions & 8 deletions
@@ -267,10 +267,9 @@ async def _collective_rpc(
 
     async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        if self.state_dict_meta is None:
-            self.state_dict_meta = update_weight_args_list
-        for args in self.state_dict_meta:
-            await self._collective_rpc("update_weight", args=args)
+        if update_weight_args_list is not None:
+            await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+        await self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True
@@ -287,7 +286,6 @@ async def init_process_group(
         update_with_checkpoint: bool = True,
         state_dict_meta: dict = None,
     ):
-        self.state_dict_meta = state_dict_meta
         return await self._collective_rpc(
             "init_process_group",
             args=(
@@ -299,12 +297,10 @@
                 backend,
                 timeout,
                 update_with_checkpoint,
+                state_dict_meta,
             ),
         )
 
-    async def update_weight(self, name, dtype, shape, empty_cache=False):
-        return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
     async def run_api_server(self):
         """Run the OpenAI API server in a Ray actor.
 
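
The reworked sync_model above replaces the old per-tensor RPC loop with a two-phase protocol: the list of (name, dtype, shape) entries is pushed once through a "set_state_dict_meta" RPC, and every later sync only issues a bare "update_weight" trigger. A toy asyncio sketch of the caller-side behavior; ToyEngine, its RPC log, and main() are invented, and only the branching inside sync_model mirrors the diff:

    import asyncio
    from typing import List, Optional, Tuple


    class ToyEngine:
        def __init__(self) -> None:
            self.rpc_log: List[str] = []

        async def _collective_rpc(self, method: str, args: Tuple = ()) -> None:
            # Stand-in for broadcasting an RPC to all vLLM workers.
            self.rpc_log.append(method)

        async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
            # Mirrors the commit: metadata is only (re)sent when explicitly provided.
            if update_weight_args_list is not None:
                await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
            await self._collective_rpc("update_weight")
            return True


    async def main() -> None:
        engine = ToyEngine()
        await engine.sync_model([("layer.weight", "bfloat16", (4, 4))])  # first sync: send meta
        await engine.sync_model()  # later syncs: trigger only
        print(engine.rpc_log)  # ['set_state_dict_meta', 'update_weight', 'update_weight']


    asyncio.run(main())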

trinity/common/models/vllm_model.py

Lines changed: 4 additions & 9 deletions
@@ -100,7 +100,6 @@ def init_process_group(
         update_with_checkpoint: bool = True,
         state_dict_meta: dict = None,
     ):
-        self.state_dict_meta = state_dict_meta
         return self.llm.collective_rpc(
             "init_process_group",
             args=(
@@ -112,12 +111,10 @@
                 backend,
                 timeout,
                 update_with_checkpoint,
+                state_dict_meta,
             ),
         )
 
-    def update_weight(self, name, dtype, shape, empty_cache=False):
-        return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
-
     def reset_prefix_cache(self):
         self.llm.llm_engine.reset_prefix_cache()
 
@@ -279,11 +276,9 @@ def has_api_server(self) -> bool:
 
     def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool:
         """Sync model weights to vLLM."""
-        if self.state_dict_meta is None:
-            self.state_dict_meta = update_weight_args_list
-        with self.lock:
-            for args in self.state_dict_meta:
-                self.llm.collective_rpc("update_weight", args=args)
+        if update_weight_args_list is not None:
+            self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,))
+        self._collective_rpc("update_weight")
         self.logger.info("Sync model weights to vLLM successfully.")
         self.ckp_version += 1
         return True
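
On the worker side, the two RPC targets presumably split the same way: cache the metadata once, then replay it on every weight update. The actual vLLM worker extension is not part of this commit, so the method bodies below are assumptions sketched purely for illustration:

    from typing import List, Tuple


    class ToyWorker:
        def __init__(self) -> None:
            self.state_dict_meta: List[Tuple[str, str, Tuple[int, ...]]] = []

        def set_state_dict_meta(self, state_dict_meta: List[Tuple[str, str, Tuple[int, ...]]]) -> None:
            # Cache the (name, dtype, shape) entries so later updates need no arguments.
            self.state_dict_meta = state_dict_meta

        def update_weight(self) -> int:
            # Walk the cached metadata; a real worker would receive each tensor
            # (for example via an NCCL broadcast) and load it into the model here.
            for name, dtype, shape in self.state_dict_meta:
                _ = (name, dtype, shape)
            return len(self.state_dict_meta)


    worker = ToyWorker()
    worker.set_state_dict_meta([("lm_head.weight", "bfloat16", (32000, 4096))])
    print(worker.update_weight())  # 1 entry processed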
