1+ import os
2+ import sys
3+ from copy import deepcopy
4+
5+ import torch .distributed as dist
6+ from torchdata .stateful_dataloader import StatefulDataLoader
7+
8+ from areal .api .alloc_mode import AllocationMode
9+ from areal .api .cli_args import GRPOConfig , load_expr_config
10+ from areal .api .io_struct import FinetuneSpec , StepInfo , WeightUpdateMeta
11+ from areal .dataset import get_custom_dataset
12+ from areal .engine .ppo .actor import FSDPPPOActor
13+ from areal .engine .sglang_remote import RemoteSGLangEngine
14+ from areal .platforms import current_platform
15+ from areal .utils import seeding , stats_tracker
16+ from areal .utils .data import (
17+ broadcast_tensor_container ,
18+ cycle_dataloader ,
19+ tensor_container_to ,
20+ )
21+ from areal .utils .device import log_gpu_stats
22+ from areal .utils .evaluator import Evaluator
23+ from areal .utils .hf_utils import load_hf_tokenizer
24+ from areal .utils .recover import RecoverHandler
25+ from areal .utils .saver import Saver
26+ from areal .utils .stats_logger import StatsLogger
27+ from areal .workflow .rlvr import RLVRWorkflow
28+
29+ from typing import TYPE_CHECKING , Optional
30+ from datasets import load_dataset
31+ from datasets .distributed import split_dataset_by_node
32+ if TYPE_CHECKING :
33+ from datasets import Dataset
34+ from transformers .processing_utils import ProcessorMixin
35+ from transformers .tokenization_utils_fast import PreTrainedTokenizerFast
36+
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
    """Binary math reward: 1 if the parsed completion matches `answer`, else 0.

    Extra positional fields (`prompt`, `prompt_ids`, `completion_ids`) are part
    of the workflow's reward-function signature and are intentionally unused.
    """
    # Imported lazily so the reward parser is only loaded in rollout workers.
    from areal.reward.math_parser import process_results

    results = process_results(completions, answer)
    return int(results[0])
41+
def load_greso_dataset(
    path: str,
    rank: int,
    world_size: int,
    type: str = "sft",
    split: Optional[str] = None,
    max_length: Optional[int] = None,
    tokenizer: Optional["PreTrainedTokenizerFast"] = None,
    processor: Optional["ProcessorMixin"] = None,
    **kwargs,
) -> "Dataset":
    """Load a GReSO parquet dataset and shard it across data-parallel ranks.

    Args:
        path: Directory containing the parquet shards (passed as `data_dir`).
        rank: Data-parallel rank of the calling process.
        world_size: Total number of data-parallel ranks.
        type: Dataset flavor tag; unused here, kept for signature parity with
            the project's custom-dataset loaders.
        split: Split to load, e.g. ``"train"`` or ``"test"``.
        max_length: If given together with ``tokenizer``, drop samples whose
            first-message content tokenizes to more than this many tokens.
        tokenizer: Tokenizer used by the length filter.
        processor: Unused; kept for signature parity.

    Returns:
        The mapped, optionally length-filtered dataset shard for this rank.
    """
    dataset = load_dataset("parquet", data_dir=path, split=split)

    def process(sample):
        # Keep only the fields the RLVR workflow consumes.
        return {"messages": sample["messages"], "answer": sample["answer"]}

    dataset = dataset.map(process)

    # Filter out sequences longer than max_length if tokenizer and max_length
    # are provided. Fix: the original guard checked only `max_length`, so a
    # configured max_length with the default tokenizer=None crashed on
    # `tokenizer.encode`; the filter requires both, as the comment intended.
    if max_length is not None and tokenizer is not None:

        def filter_length(sample):
            # Tokenize the first message's content to check its length.
            content = sample["messages"][0]["content"]
            tokens = tokenizer.encode(content)
            return len(tokens) <= max_length

        dataset = dataset.filter(filter_length)

    # Shard so each data-parallel rank sees a disjoint slice.
    dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    return dataset
73+
74+
def main(args):
    """GRPO training entry point: FSDP actor + remote SGLang rollout on GReSO data.

    Builds the train/eval datasets and dataloaders, initializes the actor (and
    an optional KL reference model), the rollout engines, and the
    saver/evaluator/recovery utilities, then runs the training loop:
    rollout -> (logp recompute) -> advantages -> PPO update -> weight sync ->
    save/checkpoint/eval/log.
    """
    config, _ = load_expr_config(args, GRPOConfig)
    config: GRPOConfig

    # NOTE(review): assumes a torchrun-style launcher exported RANK; if unset,
    # int(None) raises TypeError — confirm the intended launch environment.
    rank = int(os.getenv("RANK"))
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Seed with a per-rank key so trainer processes draw distinct streams.
    seeding.set_random_seed(config.seed, key=f"trainer{rank}")
    allocation_mode = AllocationMode.from_str(config.allocation_mode)
    parallel_strategy = allocation_mode.train
    assert parallel_strategy is not None

    # Initialize train engine
    actor = FSDPPPOActor(config=config.actor)
    actor.create_process_group(parallel_strategy=parallel_strategy)

    # Each data-parallel rank loads its own shard (sharding happens inside
    # load_greso_dataset via split_dataset_by_node).
    train_dataset = load_greso_dataset(
        path=config.train_dataset.path,
        rank=actor.data_parallel_rank,
        world_size=actor.data_parallel_world_size,
        split="train",
        max_length=config.train_dataset.max_length,
        type=config.train_dataset.type,
        tokenizer=tokenizer,
    )
    valid_dataset = load_greso_dataset(
        path=config.valid_dataset.path,
        rank=actor.data_parallel_rank,
        world_size=actor.data_parallel_world_size,
        split="test",
        max_length=config.valid_dataset.max_length,
        type=config.valid_dataset.type,
        tokenizer=tokenizer,
    )

    # Create dataset and dataloaders. Batch sizes are global, so divide by the
    # data-parallel world size for the per-rank loader.
    train_dataloader = StatefulDataLoader(
        train_dataset,
        batch_size=config.train_dataset.batch_size // actor.data_parallel_world_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=config.train_dataset.num_workers,
        collate_fn=lambda x: x,  # keep raw sample dicts; no tensor collation
        drop_last=config.train_dataset.drop_last,
    )
    valid_dataloader = StatefulDataLoader(
        valid_dataset,
        batch_size=config.valid_dataset.batch_size // actor.data_parallel_world_size,
        shuffle=config.valid_dataset.shuffle,
        num_workers=config.valid_dataset.num_workers,
        collate_fn=lambda x: x,
        drop_last=config.valid_dataset.drop_last,
    )
    ft_spec = FinetuneSpec(
        total_train_epochs=config.total_train_epochs,
        dataset_size=len(train_dataloader) * config.train_dataset.batch_size,
        train_batch_size=config.train_dataset.batch_size,
    )

    # Initialize inference engine
    rollout = RemoteSGLangEngine(config.rollout)
    rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
    eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout))
    # NOTE: eval does not have any offpolicyness control
    eval_rollout.config.max_head_offpolicyness = int(1e12)
    eval_rollout.initialize()

    actor.initialize(None, ft_spec)
    # The reference model is only built when a KL penalty is active.
    ref = None
    if config.actor.kl_ctl > 0 and config.ref is not None:
        ref = FSDPPPOActor(config=config.ref)
        ref.create_process_group(parallel_strategy=parallel_strategy)
        ref.initialize(None, ft_spec)

    # NOTE: Weight update meta only requires address and free port of rank 0,
    # but `WeightUpdateMeta.from_fsdp_xccl` has to be executed on all ranks
    # due to `engine.get_param_specs()`.
    # Therefore, we create weight update meta on all ranks, then broadcast the one on rank 0.
    weight_update_meta = [
        WeightUpdateMeta.from_fsdp_xccl(
            AllocationMode.from_str(config.allocation_mode), actor
        )
    ]
    dist.broadcast_object_list(weight_update_meta, src=0)
    weight_update_meta = weight_update_meta[0]

    # Create rollout workflow. Ensure pad/eos terminate generation.
    if tokenizer.pad_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.pad_token_id)
    if tokenizer.eos_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(tokenizer.eos_token_id)
    workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig,
        tokenizer=tokenizer,
        enable_thinking=False,
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated"
        ),
    )
    # Eval uses a fixed sampling temperature and its own stats scope/dump dir.
    eval_workflow = RLVRWorkflow(
        reward_fn=gsm8k_reward_fn,
        gconfig=config.gconfig.new(temperature=0.6),
        tokenizer=tokenizer,
        enable_thinking=False,
        rollout_stat_scope="eval-rollout",
        dump_dir=os.path.join(
            StatsLogger.get_log_path(config.stats_logger), "generated-eval"
        ),
    )

    # Run training.
    saver = Saver(config.saver, ft_spec)
    stats_logger = StatsLogger(config.stats_logger, ft_spec)
    evaluator = Evaluator(config.evaluator, ft_spec)

    # Restore actor/loader/logger state from a previous run, if any, and
    # resume from the step after the last recovered one.
    recover_handler = RecoverHandler(config.recover, ft_spec)
    recover_info = recover_handler.load(
        actor,
        saver,
        evaluator,
        stats_logger,
        train_dataloader,
        inference_engine=rollout,
        weight_update_meta=weight_update_meta,
    )
    start_step = (
        recover_info.last_step_info.next().global_step
        if recover_info is not None
        else 0
    )

    total_epochs = config.total_train_epochs
    steps_per_epoch = len(train_dataloader)
    max_steps = total_epochs * steps_per_epoch

    data_generator = cycle_dataloader(train_dataloader)
    for global_step in range(start_step, max_steps):
        epoch = global_step // steps_per_epoch
        step = global_step % steps_per_epoch
        step_info = StepInfo(
            global_step=global_step,
            epoch=epoch,
            epoch_step=step,
            steps_per_epoch=steps_per_epoch,
        )

        with stats_tracker.record_timing("rollout"):
            # Only the data-parallel head produces a batch; it is then
            # broadcast to the other ranks in its context/model-parallel group.
            batch = None
            if actor.is_data_parallel_head():
                if config.async_training:
                    batch = rollout.prepare_batch(
                        train_dataloader,
                        workflow=workflow,
                        should_accept=lambda sample: True,
                    )
                else:
                    batch = rollout.rollout_batch(
                        next(data_generator),
                        workflow=workflow,
                        should_accept=lambda sample: True,
                    )
                batch = tensor_container_to(batch, actor.device)
            batch = broadcast_tensor_container(
                batch,
                src_rank=actor.current_data_parallel_head(),
                group=actor.context_and_model_parallel_group,
            )
        # Create barrier to synchronize all rollout processes.
        dist.barrier(device_ids=[actor.device.index])
        current_platform.synchronize()

        # Optionally recompute log-probs under the current policy (proximal
        # log-probs for decoupled/recomputed-logprob training).
        if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
            with stats_tracker.record_timing("recompute_logp"):
                logp = actor.compute_logp(batch)
                batch["prox_logp"] = logp
                log_gpu_stats("recompute logp")

        if ref is not None:
            with stats_tracker.record_timing("ref_logp"):
                batch["ref_logp"] = ref.compute_logp(batch)
                log_gpu_stats("ref logp")

        with stats_tracker.record_timing("compute_advantage"):
            actor.compute_advantages(batch)
            log_gpu_stats("compute advantages")

        with (
            stats_tracker.record_timing("train_step"),
            stats_tracker.scope("grpo_actor"),
        ):
            stats = actor.ppo_update(batch)
            actor.step_lr_scheduler()
            log_gpu_stats("ppo update")

        # pause inference for updating weights, save, and evaluation
        rollout.pause()

        with stats_tracker.record_timing("update_weights"):
            # Rank 0 drives the rollout engine's weight pull; every rank
            # uploads its shard, then rank 0 waits on the transfer future.
            if dist.get_rank() == 0:
                future = rollout.update_weights(weight_update_meta)
            actor.upload_weights(weight_update_meta)
            if dist.get_rank() == 0:
                future.result()
            dist.barrier(device_ids=[actor.device.index])
            current_platform.synchronize()

        # Advance model version so rollout engines tag new generations.
        actor.set_version(global_step + 1)
        rollout.set_version(global_step + 1)
        eval_rollout.set_version(global_step + 1)

        with stats_tracker.record_timing("save"):
            saver.save(actor, epoch, step, global_step, tokenizer=tokenizer)

        with stats_tracker.record_timing("checkpoint_for_recover"):
            recover_handler.dump(
                actor,
                step_info,
                saver,
                evaluator,
                stats_logger,
                train_dataloader,
                tokenizer=tokenizer,
            )

        dist.barrier(device_ids=[actor.device.index])
        current_platform.synchronize()

        with stats_tracker.record_timing("eval"):

            def evaluate_fn():
                # Only the data-parallel head submits eval requests; the
                # trailing barrier keeps the other ranks in lockstep.
                if actor.is_data_parallel_head():
                    # Stats are logged in workflow
                    # and will be exported later
                    cnt = 0
                    for data in valid_dataloader:
                        for item in data:
                            eval_rollout.submit(item, eval_workflow)
                            cnt += 1
                    eval_rollout.wait(cnt, timeout=None)
                dist.barrier(device_ids=[actor.device.index])
                current_platform.synchronize()

            # Evaluator decides (from its config) whether this step evaluates.
            evaluator.evaluate(
                evaluate_fn,
                epoch,
                step,
                global_step,
            )

        dist.barrier(device_ids=[actor.device.index])
        current_platform.synchronize()

        # Upload statistics to the logger (e.g., wandb)
        stats[0].update(
            stats_tracker.export_all(reduce_group=actor.data_parallel_group)
        )
        stats_logger.commit(epoch, step, global_step, stats)

        dist.barrier(device_ids=[actor.device.index])
        current_platform.synchronize()

        # Resume rollout
        rollout.resume()

    # Teardown in reverse dependency order.
    stats_logger.close()
    eval_rollout.destroy()
    rollout.destroy()
    if ref is not None:
        ref.destroy()
    actor.destroy()
345+
346+
if __name__ == "__main__":
    # Forward CLI arguments (everything after the script path) to the trainer.
    cli_args = sys.argv[1:]
    main(cli_args)