22import os
33import pprint
44import sys
5+ from concurrent .futures import ThreadPoolExecutor
56
67from datasets import load_dataset
78from torchdata .stateful_dataloader import StatefulDataLoader
89
910from areal .api .cli_args import (
11+ SchedulingStrategy ,
1012 load_expr_config ,
1113)
1214from areal .api .io_struct import AllocationMode , FinetuneSpec
1315from areal .extension .asystem .api .cli_args import GRPOConfig
1416from areal .extension .asystem .ascheduler import AsystemScheduler
15- from areal .extension .asystem .controller import RolloutController
17+ from areal .extension .asystem .controller import RolloutController , TrainController
1618from areal .extension .asystem .recover import latest_checkpoint , periodic_checkpoint
1719from areal .extension .asystem .remote_hybrid_inference_worker import (
1820 RemoteHybridInferenceWorker ,
1921)
20- from areal .extension .asystem .util import ShuffleSampler
22+ from areal .extension .asystem .remote_hybrid_train_worker import RemoteHybridTrainWorker
23+ from areal .extension .asystem .util import ShuffleSampler , wait_future_ordered
2124from areal .utils import logging
2225from areal .utils .hf_utils import load_hf_tokenizer
2326from areal .utils .stats_logger import StatsLogger
@@ -112,7 +115,6 @@ def main(args):
112115 sampler = ShuffleSampler (train_dataset ),
113116 collate_fn = custom_collate_fn ,
114117 )
115-
116118 ############################## recover #########################################
117119 recover_meta_info_path = config .recover .recover_meta_info_path
118120 enable_recover = True
@@ -154,18 +156,6 @@ def main(args):
154156
155157 stats_logger = StatsLogger (
156158 config ,
157- # StatsLoggerConfig(
158- # experiment_name=config.experiment_name,
159- # trial_name=config.trial_name,
160- # fileroot=config.stats_logger.fileroot,
161- # wandb=WandBConfig(
162- # mode=config.stats_logger.wandb.mode,
163- # id_suffix=config.stats_logger.wandb.id_suffix,
164- # ),
165- # tensorboard=TensorBoardConfig(
166- # path=config.stats_logger.tensorboard.path
167- # ),
168- # ),
169159 ft_spec ,
170160 )
171161
@@ -177,31 +167,29 @@ def main(args):
177167 inference_config .pp_size = allocate_mode .gen .pp_size
178168 inference_config .storage_path = f"{ config .rollout .storage_path } /{ config .experiment_name } /{ config .trial_name } "
179169 inference_config .seed = config .seed
170+ inference_config .scheduling_strategy = (
171+ SchedulingStrategy (type = "colocation" , target = "actor" )
172+ if config .enable_colocate_mode
173+ else inference_config .scheduling_strategy
174+ )
175+
180176 rollout = RolloutController (
181- RemoteHybridInferenceWorker ( inference_config ) ,
177+ RemoteHybridInferenceWorker ,
182178 inference_config ,
183179 scheduler ,
184180 )
185181
186- # actor = TrainController(
187- # RemoteHybridTrainWorker(config.actor),
188- # TrainControllerConfig(
189- # experiment_name=config.experiment_name,
190- # trial_name=config.trial_name,
191- # allocation_mode=config.allocation_mode,
192- # enable_colocate_mode=config.enable_colocate_mode,
193- # group_size=config.actor.hybrid_engine.group_size,
194- # storage_prefix=config.storage_prefix,
195- # ),
196- # scheduler,
197- # )
198- #
199- # # engine initialize
200- # # initialize actor first for colocation mode
201- # # actor.initialize()
202- # # rollout.initialize(colocate_with=actor if config.enable_colocate_mode else None)
203- #
204- # ref = None
182+ actor = TrainController (
183+ RemoteHybridTrainWorker ,
184+ config .actor ,
185+ scheduler ,
186+ )
187+
188+ allocation_mode = AllocationMode .from_str (config .allocation_mode )
189+
190+ # rollout.initialize(role="rollout", alloc_mode=allocation_mode)
191+ # actor.initialize(role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec)
192+ ref = None
205193 # if config.actor.hybrid_engine.wrap_policy.kl_ctl > 0:
206194 # ref = DistributedReferenceController(
207195 # RemoteHybridTrainWorker(config.ref),
@@ -216,60 +204,64 @@ def main(args):
216204 # scheduler,
217205 # )
218206 # # ref.initialize()
219- #
220- # # 共卡:actor -> rollout 按顺序,reference 可以并行。
221- # # 分卡:actor、rollout、reference 三者并行。
222- #
223- # # helper function for initializing Reference controller
224- # def init_ref_controller_helper(ref):
225- # if ref is not None:
226- # logger.info("ref is not none, initializing reference controller")
227- # ref.initialize()
228- #
229- # # helper function for initializing Train Controller (actor) & Rollout Controller
230- # def init_train_and_rollout_controller_helper(actor, rollout):
231- # logger.info("initializing trainer controller and rollout controller")
232- # actor.initialize()
233- # rollout.initialize(
234- # colocate_with=actor if config.enable_colocate_mode else None
235- # )
236- #
237- # # helper function for initializing Rollout controller
238- # def init_rollout_controller_helper(rollout):
239- # logger.info("initializing rollout controller")
240- # rollout.initialize(
241- # colocate_with=actor if config.enable_colocate_mode else None
242- # )
243- #
244- # if config.enable_colocate_mode:
245- # logger.info(
246- # f"initializing all controllers in colocation mode {config.enable_colocate_mode}"
247- # )
248- # with ThreadPoolExecutor(max_workers=2) as executor:
249- # futures = [
250- # executor.submit(
251- # init_train_and_rollout_controller_helper, actor, rollout
252- # ),
253- # executor.submit(init_ref_controller_helper, ref),
254- # ]
255- # wait_future_ordered(futures)
256- # logger.info(
257- # f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
258- # )
259- # else:
260- # logger.info(
261- # f"initializing all controllers in colocation mode {config.enable_colocate_mode}"
262- # )
263- # with ThreadPoolExecutor(max_workers=3) as executor:
264- # futures = [
265- # executor.submit(actor.initialize),
266- # executor.submit(init_rollout_controller_helper, rollout),
267- # executor.submit(init_ref_controller_helper, ref),
268- # ]
269- # wait_future_ordered(futures)
270- # logger.info(
271- # f"initialized all controllers in colocation mode {config.enable_colocate_mode}"
272- # )
207+
208+ # Colocated mode: initialize actor -> rollout in order; reference may run in parallel.
209+ # Disaggregated mode: actor, rollout, and reference all initialize in parallel.
210+
def init_train_and_rollout_controller_helper(actor, rollout):
    """Bring up the train controller (actor) and the rollout controller.

    The two initializations are deliberately sequential: per the mode notes
    above, colocated mode requires the actor to be initialized before the
    rollout controller (they share the same devices) — presumably so rollout
    can attach to the actor's allocation; confirm against the controllers.

    Closes over ``allocation_mode`` and ``ft_spec`` from the enclosing scope.
    """
    logger.info("initializing trainer controller and rollout controller")
    init_plan = (
        (actor, {"role": "actor", "alloc_mode": allocation_mode, "ft_spec": ft_spec}),
        (rollout, {"role": "rollout", "alloc_mode": allocation_mode}),
    )
    for controller, kwargs in init_plan:
        controller.initialize(**kwargs)
216+
217+ if config .enable_colocate_mode :
218+ with ThreadPoolExecutor (max_workers = 2 ) as executor :
219+ futures = [
220+ executor .submit (
221+ init_train_and_rollout_controller_helper , actor , rollout
222+ ),
223+ ]
224+ if ref is not None :
225+ futures .append (
226+ executor .submit (
227+ ref .initialize ,
228+ role = "ref" ,
229+ alloc_mode = allocation_mode ,
230+ ft_spec = ft_spec ,
231+ )
232+ )
233+
234+ wait_future_ordered (futures )
235+ logger .info (
236+ f"initialized all controllers in colocation mode { config .enable_colocate_mode } "
237+ )
238+ else :
239+ with ThreadPoolExecutor (max_workers = 3 ) as executor :
240+ futures = [
241+ executor .submit (
242+ actor .initialize ,
243+ role = "actor" ,
244+ alloc_mode = allocation_mode ,
245+ ft_spec = ft_spec ,
246+ ),
247+ executor .submit (
248+ rollout .initialize , role = "rollout" , alloc_mode = allocation_mode
249+ ),
250+ ]
251+ if ref is not None :
252+ futures .append (
253+ executor .submit (
254+ ref .initialize ,
255+ role = "ref" ,
256+ alloc_mode = allocation_mode ,
257+ ft_spec = ft_spec ,
258+ )
259+ )
260+
261+ wait_future_ordered (futures )
262+ logger .info (
263+ f"initialized all controllers in colocation mode { config .enable_colocate_mode } "
264+ )
273265 #
274266 # if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
275267 # config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)