
Commit 331bc2f

Author: daihao
Commit message: initialize success
Parent: 9fcba4d

File tree: 6 files changed, +615 / -316 lines

areal/examples/grpo_trainer.py (81 additions, 89 deletions)
@@ -2,22 +2,25 @@
 import os
 import pprint
 import sys
+from concurrent.futures import ThreadPoolExecutor
 
 from datasets import load_dataset
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from areal.api.cli_args import (
+    SchedulingStrategy,
     load_expr_config,
 )
 from areal.api.io_struct import AllocationMode, FinetuneSpec
 from areal.extension.asystem.api.cli_args import GRPOConfig
 from areal.extension.asystem.ascheduler import AsystemScheduler
-from areal.extension.asystem.controller import RolloutController
+from areal.extension.asystem.controller import RolloutController, TrainController
 from areal.extension.asystem.recover import latest_checkpoint, periodic_checkpoint
 from areal.extension.asystem.remote_hybrid_inference_worker import (
     RemoteHybridInferenceWorker,
 )
-from areal.extension.asystem.util import ShuffleSampler
+from areal.extension.asystem.remote_hybrid_train_worker import RemoteHybridTrainWorker
+from areal.extension.asystem.util import ShuffleSampler, wait_future_ordered
 from areal.utils import logging
 from areal.utils.hf_utils import load_hf_tokenizer
 from areal.utils.stats_logger import StatsLogger
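Note on the new imports: ThreadPoolExecutor and wait_future_ordered are what the initialization code later in this diff uses to bring controllers up concurrently. The body of wait_future_ordered (from areal.extension.asystem.util) is not part of this diff; a minimal sketch of the behavior its call sites imply, assuming it blocks on each future in submission order and re-raises the first failure, might look like:

    from concurrent.futures import Future

    def wait_future_ordered(futures: list[Future]) -> list:
        # Hypothetical re-implementation, not the real AReaL helper:
        # block on each future in the order it was submitted and collect
        # results; Future.result() re-raises any worker exception, so a
        # failed controller initialization aborts startup loudly.
        results = []
        for fut in futures:
            results.append(fut.result())
        return results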
@@ -112,7 +115,6 @@ def main(args):
         sampler=ShuffleSampler(train_dataset),
         collate_fn=custom_collate_fn,
     )
-
     ############################## recover #########################################
     recover_meta_info_path = config.recover.recover_meta_info_path
     enable_recover = True
154156

155157
stats_logger = StatsLogger(
156158
config,
157-
# StatsLoggerConfig(
158-
# experiment_name=config.experiment_name,
159-
# trial_name=config.trial_name,
160-
# fileroot=config.stats_logger.fileroot,
161-
# wandb=WandBConfig(
162-
# mode=config.stats_logger.wandb.mode,
163-
# id_suffix=config.stats_logger.wandb.id_suffix,
164-
# ),
165-
# tensorboard=TensorBoardConfig(
166-
# path=config.stats_logger.tensorboard.path
167-
# ),
168-
# ),
169159
ft_spec,
170160
)
171161

@@ -177,31 +167,29 @@ def main(args):
     inference_config.pp_size = allocate_mode.gen.pp_size
     inference_config.storage_path = f"{config.rollout.storage_path}/{config.experiment_name}/{config.trial_name}"
     inference_config.seed = config.seed
+    inference_config.scheduling_strategy = (
+        SchedulingStrategy(type="colocation", target="actor")
+        if config.enable_colocate_mode
+        else inference_config.scheduling_strategy
+    )
+
     rollout = RolloutController(
-        RemoteHybridInferenceWorker(inference_config),
+        RemoteHybridInferenceWorker,
         inference_config,
         scheduler,
     )
 
-    # actor = TrainController(
-    #     RemoteHybridTrainWorker(config.actor),
-    #     TrainControllerConfig(
-    #         experiment_name=config.experiment_name,
-    #         trial_name=config.trial_name,
-    #         allocation_mode=config.allocation_mode,
-    #         enable_colocate_mode=config.enable_colocate_mode,
-    #         group_size=config.actor.hybrid_engine.group_size,
-    #         storage_prefix=config.storage_prefix,
-    #     ),
-    #     scheduler,
-    # )
-    #
-    # # engine initialize
-    # # initialize actor first for colocation mode
-    # # actor.initialize()
-    # # rollout.initialize(colocate_with=actor if config.enable_colocate_mode else None)
-    #
-    # ref = None
+    actor = TrainController(
+        RemoteHybridTrainWorker,
+        config.actor,
+        scheduler,
+    )
+
+    allocation_mode = AllocationMode.from_str(config.allocation_mode)
+
+    # rollout.initialize(role="rollout", alloc_mode=allocation_mode)
+    # actor.initialize(role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec)
+    ref = None
     # if config.actor.hybrid_engine.wrap_policy.kl_ctl > 0:
     #     ref = DistributedReferenceController(
     #         RemoteHybridTrainWorker(config.ref),
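Note on the constructor change above: both controllers now receive the worker class (RemoteHybridInferenceWorker, RemoteHybridTrainWorker) plus its config, rather than a pre-built worker instance. A hedged sketch of why a controller might want the class instead of an instance; the names below (ControllerSketch, create_workers, n_replicas) are illustrative, not the real AReaL API:

    from typing import Type

    class ControllerSketch:
        def __init__(self, worker_cls: Type, worker_config, scheduler):
            self.worker_cls = worker_cls      # a class, not an instance
            self.worker_config = worker_config
            self.scheduler = scheduler

        def create_workers(self, n_replicas: int) -> list:
            # Deferring construction lets the controller decide how many
            # workers to build once the scheduler has granted resources.
            return [self.worker_cls(self.worker_config) for _ in range(n_replicas)]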
@@ -216,60 +204,64 @@ def main(args):
     #         scheduler,
     #     )
     # # ref.initialize()
-    #
-    # # Colocated GPUs: actor -> rollout in order; reference can run in parallel.
-    # # Separate GPUs: actor, rollout, and reference all run in parallel.
-    #
-    # # helper function for initializing Reference controller
-    # def init_ref_controller_helper(ref):
-    #     if ref is not None:
-    #         logger.info("ref is not none, initializing reference controller")
-    #         ref.initialize()
-    #
-    # # helper function for initializing Train Controller (actor) & Rollout Controller
-    # def init_train_and_rollout_controller_helper(actor, rollout):
-    #     logger.info("initializing trainer controller and rollout controller")
-    #     actor.initialize()
-    #     rollout.initialize(
-    #         colocate_with=actor if config.enable_colocate_mode else None
-    #     )
-    #
-    # # helper function for initializing Rollout controller
-    # def init_rollout_controller_helper(rollout):
-    #     logger.info("initializing rollout controller")
-    #     rollout.initialize(
-    #         colocate_with=actor if config.enable_colocate_mode else None
-    #     )
-    #
-    # if config.enable_colocate_mode:
-    #     logger.info(
-    #         f"initializing all controllers in colocation mode {config.enable_colocate_mode}"
-    #     )
-    #     with ThreadPoolExecutor(max_workers=2) as executor:
-    #         futures = [
-    #             executor.submit(
-    #                 init_train_and_rollout_controller_helper, actor, rollout
-    #             ),
-    #             executor.submit(init_ref_controller_helper, ref),
-    #         ]
-    #         wait_future_ordered(futures)
-    #     logger.info(
-    #         f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
-    #     )
-    # else:
-    #     logger.info(
-    #         f"initializing all controllers in colocation mode {config.enable_colocate_mode}"
-    #     )
-    #     with ThreadPoolExecutor(max_workers=3) as executor:
-    #         futures = [
-    #             executor.submit(actor.initialize),
-    #             executor.submit(init_rollout_controller_helper, rollout),
-    #             executor.submit(init_ref_controller_helper, ref),
-    #         ]
-    #         wait_future_ordered(futures)
-    #     logger.info(
-    #         f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
-    #     )
+
+    # Colocated GPUs: actor -> rollout in order; reference can run in parallel.
+    # Separate GPUs: actor, rollout, and reference all run in parallel.
+
+    # helper function for initializing Train Controller (actor) & Rollout Controller
+    def init_train_and_rollout_controller_helper(actor, rollout):
+        logger.info("initializing trainer controller and rollout controller")
+        actor.initialize(role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec)
+        rollout.initialize(role="rollout", alloc_mode=allocation_mode)
+
+    if config.enable_colocate_mode:
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            futures = [
+                executor.submit(
+                    init_train_and_rollout_controller_helper, actor, rollout
+                ),
+            ]
+            if ref is not None:
+                futures.append(
+                    executor.submit(
+                        ref.initialize,
+                        role="ref",
+                        alloc_mode=allocation_mode,
+                        ft_spec=ft_spec,
+                    )
+                )
+
+            wait_future_ordered(futures)
+        logger.info(
+            f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
+        )
+    else:
+        with ThreadPoolExecutor(max_workers=3) as executor:
+            futures = [
+                executor.submit(
+                    actor.initialize,
+                    role="actor",
+                    alloc_mode=allocation_mode,
+                    ft_spec=ft_spec,
+                ),
+                executor.submit(
+                    rollout.initialize, role="rollout", alloc_mode=allocation_mode
+                ),
+            ]
+            if ref is not None:
+                futures.append(
+                    executor.submit(
+                        ref.initialize,
+                        role="ref",
+                        alloc_mode=allocation_mode,
+                        ft_spec=ft_spec,
+                    )
+                )
+
+            wait_future_ordered(futures)
+        logger.info(
+            f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
+        )
     #
     # if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
     #     config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
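Note on the control flow above (the two inline comments are translated from the original Chinese): in colocated mode the actor must come up before the rollout on the shared GPUs, so the two run sequentially inside one task while the reference initializes in parallel; in separate-GPU mode all three initialize in parallel. A self-contained toy sketch of just that ordering, with an invented DummyController stub standing in for the real controllers:

    from concurrent.futures import ThreadPoolExecutor

    class DummyController:
        def __init__(self, name):
            self.name = name

        def initialize(self, **kwargs):
            print(f"initialized {self.name} with {kwargs}")

    def init_all(actor, rollout, ref, colocated: bool):
        if colocated:
            # Colocated: actor then rollout in order within one task;
            # reference (if any) runs in parallel with that task.
            def actor_then_rollout():
                actor.initialize(role="actor")
                rollout.initialize(role="rollout")

            with ThreadPoolExecutor(max_workers=2) as ex:
                futures = [ex.submit(actor_then_rollout)]
                if ref is not None:
                    futures.append(ex.submit(ref.initialize, role="ref"))
                for f in futures:
                    f.result()  # propagate any initialization error
        else:
            # Separate GPUs: no ordering constraint, all in parallel.
            with ThreadPoolExecutor(max_workers=3) as ex:
                futures = [
                    ex.submit(actor.initialize, role="actor"),
                    ex.submit(rollout.initialize, role="rollout"),
                ]
                if ref is not None:
                    futures.append(ex.submit(ref.initialize, role="ref"))
                for f in futures:
                    f.result()

    init_all(DummyController("actor"), DummyController("rollout"), None, colocated=True)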
