1+ import concurrent .futures
12import json
23import os
34import pprint
4- import sys
5- import asyncio
65import shutil
7- import concurrent . futures
6+ import sys
87from concurrent .futures import ThreadPoolExecutor
98
109from datasets import load_dataset
1110from torchdata .stateful_dataloader import StatefulDataLoader
1211
12+ from realhf .api .core .data_api import load_hf_tokenizer
13+
1314from areal .api .cli_args import (
1415 SchedulingStrategy ,
1516 load_expr_config ,
1617)
18+ from areal .api .engine_api import WeightUpdateMeta
1719from areal .api .io_struct import AllocationMode , FinetuneSpec
1820from areal .extension .asystem .api .cli_args import GRPOConfig
1921from areal .extension .asystem .ascheduler import AsystemScheduler
2325 RemoteHybridInferenceWorker ,
2426)
2527from areal .extension .asystem .remote_hybrid_train_worker import RemoteHybridTrainWorker
26- from areal .extension .asystem .util import ShuffleSampler , wait_future_ordered
28+ from areal .extension .asystem .utils .align_tools import summarize_rewards
29+ from areal .extension .asystem .utils .util import ShuffleSampler , wait_future_ordered
2730from areal .utils import logging , stats_tracker
28- from areal .utils .hf_utils import load_hf_tokenizer
2931from areal .utils .stats_logger import StatsLogger
30- from areal .api .engine_api import WeightUpdateMeta
3132
3233logger = logging .getLogger ("Trainer" )
3334
3435
35- def custom_collate_fn (batch ):
36- all_keys = set ().union (* (d .keys () for d in batch ))
37- collated_batch = {}
38- for key in all_keys :
39- collated_batch [key ] = [d .get (key ) for d in batch ]
40- return collated_batch
41-
42-
4336def clear_dir (path ):
4437 if os .path .exists (path ):
4538 for filename in os .listdir (path ):
@@ -54,13 +47,9 @@ def main(args):
5447 config , _ = load_expr_config (args , GRPOConfig )
5548 config : GRPOConfig
5649
57- if config .gconfig .max_tokens is None :
58- logger .info (
59- "config.gconfig.max_tokens is None, set it to max_new_tokens + max_prompt_len"
60- )
61- config .gconfig .max_tokens = (
62- config .gconfig .max_new_tokens + config .train_dataset .max_length
63- )
50+ config .gconfig .max_tokens = (
51+ config .gconfig .max_new_tokens + config .train_dataset .max_length
52+ )
6453
6554 if config .enable_colocate_mode :
6655 config .rollout .engine_config ["enable_memory_saver" ] = True
@@ -122,22 +111,9 @@ def main(args):
122111 train_dataset = dataset ["train" ]
123112 train_dataset = train_dataset .filter (
124113 lambda x : len (tokenizer .encode (x ["prompt" ]))
125- <= config .train_dataset .max_length
114+ <= config .train_dataset .max_length
126115 )
127116
128- def process (sample ):
129- messages = [
130- {
131- "role" : "user" ,
132- "content" : sample ["prompt" ]
133- .replace ("<role>HUMAN</role>" , "" )
134- .replace ("<role>ASSISTANT</role>" , "" ),
135- }
136- ]
137- return {"messages" : messages }
138-
139- train_dataset = train_dataset .map (process ).remove_columns (["prompt" ])
140-
141117 dataloader = StatefulDataLoader (
142118 train_dataset ,
143119 batch_size = config .train_dataset .batch_size ,
@@ -221,16 +197,21 @@ def process(sample):
221197 if config .actor .hybrid_engine .wrap_policy .kl_ctl > 0 :
222198 ref = TrainController (
223199 RemoteHybridTrainWorker ,
224- config .actor ,
200+ config .ref ,
225201 scheduler ,
226202 )
227203
228204 allocation_mode = AllocationMode .from_str (config .allocation_mode )
229205
230206 def init_train_and_rollout_controller_helper (actor , rollout ):
231207 logger .info ("initializing trainer controller and rollout controller" )
232- actor .initialize (role = "actor" , alloc_mode = allocation_mode , ft_spec = ft_spec ,
233- group_size = config .gconfig .n_samples , )
208+ actor .initialize (
209+ role = "actor" ,
210+ alloc_mode = allocation_mode ,
211+ ft_spec = ft_spec ,
212+ group_size = config .gconfig .n_samples ,
213+ enable_colocate_mode = config .enable_colocate_mode ,
214+ )
234215 rollout .initialize (role = "rollout" , alloc_mode = allocation_mode )
235216
236217 if config .enable_colocate_mode :
@@ -254,7 +235,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
254235 )
255236
256237 wait_future_ordered (futures )
257- logger .info (f"initialized all controllers in colocation mode { config .enable_colocate_mode } " )
238+ logger .info (
239+ f"initialized all controllers in colocation mode { config .enable_colocate_mode } "
240+ )
258241 else :
259242 with ThreadPoolExecutor (max_workers = 3 ) as executor :
260243 futures = [
@@ -268,7 +251,10 @@ def init_train_and_rollout_controller_helper(actor, rollout):
268251 storage_prefix = config .storage_prefix ,
269252 ),
270253 executor .submit (
271- rollout .initialize , role = "rollout" , alloc_mode = allocation_mode
254+ rollout .initialize ,
255+ role = "rollout" ,
256+ alloc_mode = allocation_mode ,
257+ enable_colocate_mode = config .enable_colocate_mode ,
272258 ),
273259 ]
274260 if ref is not None :
@@ -324,15 +310,19 @@ def init_train_and_rollout_controller_helper(actor, rollout):
324310 )
325311 clear_dir (weight_update_config .path )
326312 else :
327- with concurrent .futures .ThreadPoolExecutor (max_workers = 2 ) as executor :
313+ with concurrent .futures .ThreadPoolExecutor (
314+ max_workers = 2
315+ ) as executor :
328316 upload_future = executor .submit (
329317 actor .upload_weights , weight_update_config
330318 )
331319 update_future = executor .submit (
332320 rollout .update_weights , weight_update_config
333321 )
334322 wait_future_ordered ([upload_future , update_future ])
335- logger .info (f"{ weight_update_config .type } update weight succeeded, step: { step } " )
323+ logger .info (
324+ f"{ weight_update_config .type } update weight succeeded, step: { step } "
325+ )
336326
337327 with (
338328 stats_tracker .record_timing ("rollout_step" ),
@@ -360,11 +350,8 @@ def init_train_and_rollout_controller_helper(actor, rollout):
360350 reward_fn = "areal.extension.asystem.math_reward.reward_fn" ,
361351 gconfig = config .gconfig ,
362352 tokenizer = config .tokenizer_path ,
363- enable_thinking = False ,
364- dump_dir = os .path .join (
365- f"{ config .storage_prefix } /experiments/logs/root/{ config .experiment_name } /{ config .trial_name } " ,
366- "generated" ,
367- ),
353+ exp_name = config .experiment_name ,
354+ trial_name = config .trial_name ,
368355 ),
369356 )
370357 else :
@@ -375,24 +362,27 @@ def init_train_and_rollout_controller_helper(actor, rollout):
375362 reward_fn = "areal.extension.asystem.math_reward.reward_fn" ,
376363 gconfig = config .gconfig ,
377364 tokenizer = config .tokenizer_path ,
378- enable_thinking = False ,
379- dump_dir = os .path .join (
380- f"{ config .storage_prefix } /experiments/logs/root/{ config .experiment_name } /{ config .trial_name } " ,
381- "generated" ,
382- ),
365+ exp_name = config .experiment_name ,
366+ trial_name = config .trial_name ,
383367 ),
384368 )
385369
386- #TODO: calc_training_data_metrics
387- # with (stats_tracker.scope("training_data"), ):
388- # calc_training_data_metrics(rollout_res)
370+ # with (
371+ # stats_tracker.scope("training_data"),
372+ # ):
373+ # calc_training_data_metrics(batch.get_data())
389374 # calc_training_data_group_metrics(
390- # rollout_res , config.gconfig.n_samples
375+ # batch.get_data , config.gconfig.n_samples
391376 # )
392- # calc_training_data_version_metrics(rollout_res, global_step)
393- #
394- logger .info (f"rollout batch res: { batch } , reward: { batch ["rewards" ]} " )
395- with (stats_tracker .record_timing ("notify_rollout_end_event" ), ):
377+ # calc_training_data_version_metrics(batch.get_data, global_step)
378+
379+ logger .info (
380+ "rollout batch reward summary: %s" ,
381+ summarize_rewards (batch ["rewards" ]),
382+ )
383+ with (
384+ stats_tracker .record_timing ("notify_rollout_end_event" ),
385+ ):
396386 logger .info (
397387 f"start to notify_rollout_end_event, step: { step } , epoch: { epoch } "
398388 )
@@ -420,7 +410,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
420410 stats_tracker .record_timing ("train_step" ),
421411 stats_tracker .scope ("train" ),
422412 ):
423- with (stats_tracker .record_timing ("notify_train_start_event" ),):
413+ with (
414+ stats_tracker .record_timing ("notify_train_start_event" ),
415+ ):
424416 logger .info (
425417 f"start to notify_train_start_event, step: { step } , epoch: { epoch } "
426418 )
@@ -429,12 +421,14 @@ def init_train_and_rollout_controller_helper(actor, rollout):
429421 f"notify_train_start_event succeeded, step: { step } , epoch: { epoch } "
430422 )
431423
432- with (stats_tracker .record_timing ("train_distributed_batch" ), ):
424+ with (
425+ stats_tracker .record_timing ("train_distributed_batch" ),
426+ ):
433427 logger .info (f"start to train, step: { step } , epoch: { epoch } " )
434428 actor .train_batch (
435429 batch ,
436430 loss_fn = lambda logits , batch_data : None ,
437- loss_weight_fn = lambda batch_data : None
431+ loss_weight_fn = lambda batch_data : None ,
438432 )
439433 logger .info (
440434 f"train succeeded, step: { step } , epoch: { epoch } "
@@ -487,7 +481,9 @@ def init_train_and_rollout_controller_helper(actor, rollout):
487481 f"[Trainer] periodic_checkpoint recover save success, epoch:{ epoch } , epoch_step: { step } , global_step:{ global_step } "
488482 )
489483
490- with (stats_tracker .record_timing ("notify_train_end_event" ),):
484+ with (
485+ stats_tracker .record_timing ("notify_train_end_event" ),
486+ ):
491487 logger .info (
492488 f"start to notify_train_end_event, step: { step } , epoch: { epoch } "
493489 )
0 commit comments