
Commit 6f2d7c7

Support one-step ahead async RL (#93)
Parent: 99a772a

File tree: 8 files changed, +156 −30 lines


tests/trainer/trainer_test.py

Lines changed: 51 additions & 1 deletion
@@ -115,6 +115,56 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
 
 
+class TestStepAheadAsyncRL(BaseTrainerCase):
+    def test_trainer(self):
+        """Test the explore-step-ahead trainer."""
+        # train 4 steps, sync_offset=1, sync_interval=2
+        # Explorer:
+        # | 1 | 2 | 3 |sync| 4 |
+        # |---|---|---|sync|---|
+        # Trainer:
+        # | 1 | 2 |sync| 3 | 4 |
+        # |---|---|sync|---|---|
+        self.config.buffer.total_epochs = 1
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.trainer.save_interval = 4
+        self.config.synchronizer.sync_interval = 2
+        self.config.synchronizer.sync_offset = 1
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 1
+        self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 1
+
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        actor_kl_metrics = parser.metric_list("actor/kl")
+        self.assertTrue(len(actor_kl_metrics) > 0)
+        critic_kl_metrics = parser.metric_list("critic/kl")
+        self.assertTrue(len(critic_kl_metrics) > 0)
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
+
+        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+            checkpoint_root_path=self.config.checkpoint_job_dir,
+            trainer_type=self.config.trainer.trainer_type,
+            step_num=4,
+        )
+        self.assertTrue(os.path.exists(checkpoint_step_4))
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
 class TestTrainerGSM8K(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K."""
@@ -153,7 +203,7 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
 
 
-class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+class TestTrainerSFTWarmupGSM8K(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K With SFT."""
         # test both mode
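
The timing diagram in the comment above is the heart of one-step-ahead async RL: with sync_offset=1 the explorer produces one extra batch before the first weight sync, so explorer step N+1 can roll out while the trainer fits step N. A minimal sketch of the schedule those two settings imply (explorer_need_sync and trainer_need_sync are hypothetical helpers written for illustration, not the repository's actual need_sync implementations):

    # Hypothetical helpers; the real need_sync() methods live in Explorer
    # and Trainer and may differ in detail.
    def explorer_need_sync(step: int, sync_offset: int, sync_interval: int) -> bool:
        # The explorer runs `sync_offset` extra steps before its first sync.
        return step > sync_offset and (step - sync_offset) % sync_interval == 0

    def trainer_need_sync(step: int, sync_interval: int) -> bool:
        return step % sync_interval == 0

    # With sync_offset=1, sync_interval=2, and 4 steps total, the explorer
    # syncs after step 3 and the trainer after step 2, matching the diagram.
    for step in range(1, 5):
        print(step, explorer_need_sync(step, 1, 2), trainer_need_sync(step, 2))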

trinity/cli/launcher.py

Lines changed: 28 additions & 8 deletions
@@ -9,6 +9,7 @@
 import ray
 
 from trinity.common.config import Config, load_config
+from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
 from trinity.utils.log import get_logger
@@ -19,7 +20,7 @@
 
 def bench(config: Config) -> None:
     """Evaluate model."""
-    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.benchmark.remote())
@@ -33,7 +34,7 @@ def bench(config: Config) -> None:
 def explore(config: Config) -> None:
     """Run explorer."""
     try:
-        explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+        explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
@@ -46,7 +47,7 @@ def explore(config: Config) -> None:
 def train(config: Config) -> None:
     """Run trainer."""
     try:
-        trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+        trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
         ray.get(trainer.prepare.remote())
         ray.get(trainer.sync_weight.remote())
         ray.get(trainer.train.remote())
@@ -66,8 +67,8 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for different
     algorithms and tasks.
     """
-    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
-    trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
+    trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
     ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
     ray.get(
         [
@@ -81,15 +82,34 @@ def both(config: Config) -> None:
             trainer.sync_weight.remote(),
         ]
     )
-    _, _ = ray.wait(
+    ready_ref, wait_ref = ray.wait(
         [
             explorer.explore.remote(),
             trainer.train.remote(),
         ],
         num_returns=1,
     )
-    explorer.shutdown.remote(),
-    trainer.shutdown.remote(),
+
+    ready = ray.get(ready_ref[0])
+    if ready == TRAINER_NAME:
+        logger.info(
+            "===========================================================\n"
+            "> Launcher detected that the `Trainer` process has finished.\n"
+            "> Stopping the explorer process immediately.\n"
+            "==========================================================="
+        )
+        ray.wait(wait_ref, timeout=5)
+    elif ready == EXPLORER_NAME:
+        logger.info(
+            "============================================================\n"
+            "> Launcher detected that the `Explorer` process has finished.\n"
+            f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n"
+            "> You can force stop the Trainer process by pressing Ctrl+C.\n"
+            "============================================================"
+        )
+        ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout)
+    explorer.shutdown.remote()
+    trainer.shutdown.remote()
 
 
 def activate_data_module(data_workflow_url: str, config_path: str):
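
The rewritten tail of both() leans on a small Ray idiom: ray.wait with num_returns=1 returns as soon as the first of the two remote calls completes, and because explore() and train() now return EXPLORER_NAME and TRAINER_NAME respectively, the launcher can tell which side finished first. A self-contained sketch of the idiom, with a toy task standing in for the real actors:

    import time
    import ray

    ray.init()

    @ray.remote
    def run(name: str, seconds: float) -> str:
        time.sleep(seconds)
        return name  # mirrors explore()/train() returning their own name

    refs = [run.remote("explorer", 0.5), run.remote("trainer", 0.1)]
    ready_ref, wait_ref = ray.wait(refs, num_returns=1)
    print("first finished:", ray.get(ready_ref[0]))  # -> "trainer"
    ray.wait(wait_ref, timeout=5)  # grace period for the slower task
    ray.shutdown()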

trinity/common/constants.py

Lines changed: 11 additions & 0 deletions
@@ -8,6 +8,9 @@
 
 # names
 
+EXPLORER_NAME = "explorer"
+TRAINER_NAME = "trainer"
+
 ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"
 
 
@@ -92,3 +95,11 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
 
     NCCL = "nccl"
     CHECKPOINT = "checkpoint"
+
+
+class RunningStatus(Enum):
+    """Running status of explorer and trainer."""
+
+    RUNNING = "running"
+    WAITING_SYNC = "waiting_sync"
+    STOPPED = "stopped"
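
The three states line up with the coordination points this commit adds elsewhere. A sketch of the transitions as inferred from the explorer changes below; the transition table itself is illustrative, not part of the codebase:

    from enum import Enum

    class RunningStatus(Enum):
        RUNNING = "running"
        WAITING_SYNC = "waiting_sync"
        STOPPED = "stopped"

    # Transitions inferred from this commit (illustrative only):
    #   RUNNING      -> WAITING_SYNC  when the explorer enters sync_weight()
    #   WAITING_SYNC -> RUNNING       once the new weights have been received
    #   RUNNING      -> STOPPED       when the taskset raises StopIteration
    ALLOWED = {
        RunningStatus.RUNNING: {RunningStatus.WAITING_SYNC, RunningStatus.STOPPED},
        RunningStatus.WAITING_SYNC: {RunningStatus.RUNNING},
        RunningStatus.STOPPED: set(),
    }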

trinity/common/models/vllm_worker.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed
 
+from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.distributed import init_process_group, is_ipv6_address
 from trinity.utils.log import get_logger
 
@@ -60,7 +61,7 @@ def update_weight(self):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
         assert self._state_dict_meta is not None
         if self._explorer_actor is None:
-            self._explorer_actor = ray.get_actor(name="explorer")
+            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
         for name, dtype_str, shape in self._state_dict_meta:
             if self._weight_update_rank == 0:
                 weight = ray.get(self._explorer_actor.get_weight.remote(name))
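
update_weight() follows the standard collective pattern: only rank 0 fetches each named tensor from the explorer actor, then every rank joins a broadcast so all vLLM workers end up with the fresh weight. A minimal sketch of one broadcast step, assuming an already-initialized process group (broadcast_weight and fetch_fn are illustrative names, not the repository's API):

    # Hypothetical sketch: rank 0 holds the fresh weight; other ranks
    # allocate an empty buffer and receive the tensor via broadcast.
    import torch
    import torch.distributed as dist

    def broadcast_weight(name, dtype, shape, rank, fetch_fn, group=None):
        if rank == 0:
            weight = fetch_fn(name)  # e.g. ray.get(explorer.get_weight.remote(name))
        else:
            weight = torch.empty(shape, dtype=dtype)
        dist.broadcast(weight, src=0, group=group)
        return weight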

trinity/explorer/explorer.py

Lines changed: 36 additions & 12 deletions
@@ -14,7 +14,12 @@
 from trinity.buffer import get_buffer_writer
 from trinity.buffer.buffer import get_buffer_reader
 from trinity.common.config import Config
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.constants import (
+    EXPLORER_NAME,
+    ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+    RunningStatus,
+    SyncMethod,
+)
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import (
     get_checkpoint_dir_with_step_num,
@@ -50,7 +55,7 @@ def __init__(self, config: Config):
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
             name=self.config.name,
-            role="explorer",
+            role=EXPLORER_NAME,
             config=config,
         )
         self.batch_size = config.buffer.batch_size
@@ -69,6 +74,7 @@ def __init__(self, config: Config):
             self.state_dict = {}
         else:  # nccl mode
             self.state_dict_meta = []
+        self.status = RunningStatus.RUNNING
         self.logger.info("Finished initializing Explorer.")
 
     async def setup_weight_sync_group(
@@ -162,35 +168,44 @@ async def get_weight(self, name: str) -> torch.Tensor:
         """Get the weight of the loaded model (For checkpoint weights update)."""
         return self.state_dict[name]
 
-    async def explore(self) -> None:
+    async def explore(self) -> str:
         while True:
             try:
                 explore_contionue = self.explore_step()
+                if not explore_contionue:
+                    break
                 if self.need_sync():
                     self.wait_for_workflow_done()
                     await self.sync_weight()
                 if self.explore_step_num % self.config.explorer.eval_interval == 0:
                     self.wait_for_workflow_done()
                     self.eval()
-                if not explore_contionue:
-                    break
             except Exception as e:
                 self.logger.error(f"Error in Explorer: {e}")
                 break
-        self.logger.info("--------------------\n> Explorer finished.\n--------------------\n")
+        self.logger.info("--------------------\n> Explorer finished.\n--------------------")
+        return EXPLORER_NAME
 
     def explore_step(self) -> bool:
-        self.explore_step_num += 1
-        algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num)
+        algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
         # skip warmup
         if algo_config.algorithm_type == "sft":
+            self.explore_step_num += 1
             return True
         try:
             tasks = self.taskset.read()
         except StopIteration:
             self.logger.warning("No more tasks to explore. Stop exploring.")
+            self.cache.save_explorer(
+                current_step=self.explore_step_num,
+                current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+            )
+            self.status = RunningStatus.STOPPED
+            self.wait_for_workflow_done()
+            self.experience_buffer.finish()
             return False
         self.runner_pool.run_tasks(tasks)
+        self.explore_step_num += 1
         return True
 
     def need_sync(self) -> bool:
@@ -278,20 +293,25 @@ def wait_for_workflow_done(self) -> None:
             if not status.ok:
                 self.logger.error(f"Error when running task: {status.message}")
                 # submit another task to replace the failed task
-                self.runner_pool.run_tasks(self.taskset.read(batch_size=1))
+                try:
+                    tasks = self.taskset.read(batch_size=1)
+                except StopIteration:
+                    self.logger.warning("No more tasks in taskset. Stop retrying.")
+                    return
+                self.runner_pool.run_tasks(tasks)
             else:
                 for metric_name, metric_value in status.metric.items():
                     all_metrics[metric_name].append(metric_value)
         # calculate metrics
         log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout")  # type: ignore
         self.monitor.log(log_metrics, step=self.explore_step_num)
-
         self.logger.info(f"Explore step {self.explore_step_num} finished.")
 
     async def sync_weight(self) -> None:
         """Synchronize model weights."""
         # call this method before training start to load the latest model weights
-        self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.")
+        self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.")
+        self.status = RunningStatus.WAITING_SYNC
         if self.use_checkpoint_weights_update:
             await self._checkpoint_weights_update()
         else:  # nccl weights update
@@ -301,7 +321,11 @@ async def sync_weight(self) -> None:
             current_step=self.explore_step_num,
             current_task_index=self.explore_step_num * self.config.buffer.batch_size,
         )
-        self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished")
+        self.status = RunningStatus.RUNNING
+        self.logger.info(f"Explorer sync at step {self.explore_step_num} finished")
+
+    async def running_status(self) -> RunningStatus:
+        return self.status
 
     def flush_log(self, step: int) -> None:
         """Flush the log of the current step."""

trinity/trainer/trainer.py

Lines changed: 20 additions & 5 deletions
@@ -7,8 +7,15 @@
 import os
 from abc import ABC, abstractmethod
 
+import ray
+
 from trinity.common.config import Config
-from trinity.common.constants import SyncMethod
+from trinity.common.constants import (
+    EXPLORER_NAME,
+    TRAINER_NAME,
+    RunningStatus,
+    SyncMethod,
+)
 from trinity.utils.log import get_logger
 
 
@@ -19,24 +26,26 @@ def __init__(self, config: Config) -> None:
         self.config = config
         self.logger = get_logger(__name__)
         self.engine = get_trainer_wrapper(config)
+        self.explorer_ref = None
 
     def prepare(self) -> None:
         """Prepare the trainer."""
         self.engine.prepare()
 
-    def train(self):
+    def train(self) -> str:
         """Train the model."""
         while True:
             try:
                 train_continue = self.train_step()
-                if self.need_sync():
-                    self.sync_weight()
                 if not train_continue:
                     break
+                if self.need_sync():
+                    self.sync_weight()
             except Exception as e:
                 self.logger.error(f"Error in Trainer: {e}")
                 break
-        self.logger.info("--------------------\n> Trainer finished.\n--------------------\n")
+        self.logger.info("--------------------\n> Trainer finished.\n--------------------")
+        return TRAINER_NAME
 
     def train_step(self) -> bool:
         """Train one step.
@@ -53,6 +62,12 @@ def need_sync(self) -> bool:
     def sync_weight(self) -> None:
         """Sync the model weight."""
         if self.config.synchronizer.sync_method == SyncMethod.NCCL:
+            if self.explorer_ref is None:
+                self.explorer_ref = ray.get_actor(EXPLORER_NAME)
+            explorer_status = ray.get(self.explorer_ref.running_status.remote())
+            if explorer_status == RunningStatus.STOPPED:
+                self.logger.warning("Explorer has already stopped. Skipping sync weight.")
+                return
         self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.")
         self.engine.sync_weight()
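
The status check before an NCCL sync matters because an NCCL collective blocks until every participant joins: if the explorer has already stopped, the trainer would hang in the collective forever. A toy version of the handshake (the Peer actor is illustrative, not the repository's Explorer):

    import ray

    ray.init()

    @ray.remote
    class Peer:
        def __init__(self):
            self.stopped = False

        def stop(self):
            self.stopped = True

        def running_status(self):
            return "stopped" if self.stopped else "running"

    peer = Peer.options(name="peer").remote()
    ray.get(peer.stop.remote())
    if ray.get(ray.get_actor("peer").running_status.remote()) == "stopped":
        print("peer stopped; skip the collective to avoid a deadlock")
    ray.shutdown()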

trinity/trainer/verl/fsdp_workers.py

Lines changed: 6 additions & 2 deletions
@@ -71,7 +71,11 @@
 from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
 
 from trinity.common.config import AlgorithmConfig
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.constants import (
+    EXPLORER_NAME,
+    ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+    SyncMethod,
+)
 from trinity.utils.distributed import init_process_group, is_ipv6_address
 
 logger = logging.getLogger(__file__)
@@ -573,7 +577,7 @@ def setup_weight_sync_group(self):
         master_address, master_port = self.get_availale_master_addr_port()
         world_size = self.config.synchronizer.explorer_world_size + 1
         print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
-        explorer = ray.get_actor("explorer")
+        explorer = ray.get_actor(EXPLORER_NAME)
         setup_ref = explorer.setup_weight_sync_group.remote(
             master_address, master_port, self.state_dict_meta
         )
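
world_size is explorer_world_size + 1 because one trainer rank joins all of the explorer's ranks in a single weight-sync group. A hedged sketch of joining such a group with plain torch.distributed (the function name and TCP init method are assumptions; the repository uses its own init_process_group wrapper from trinity.utils.distributed):

    import torch.distributed as dist

    def join_weight_sync_group(master_address: str, master_port: int,
                               rank: int, explorer_world_size: int) -> None:
        # One trainer rank plus all explorer ranks share the group.
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=explorer_world_size + 1,
            rank=rank,
        )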
