@@ -197,7 +197,7 @@ def setup(
         )
         dataloader.load_state_dict(dataloader_state_dict)

-    print(f" ✓ Training dataloader loaded with {len(dataset)} samples")
+    print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)

     # Load validation dataset if provided
     val_dataloader: Optional[StatefulDataLoader] = None
@@ -212,12 +212,15 @@ def setup(
             shuffle=False,
             collate_fn=rl_collate_fn,
         )
-        print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples")
+        print(
+            f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
+            flush=True,
+        )

     # ==========================
     # Cluster
     # ==========================
-    print("\n▶ Setting up compute cluster...")
+    print("\n▶ Setting up compute cluster...", flush=True)
     colocated_inference = generation_config["colocated"]["enabled"]

     if colocated_inference:
@@ -233,7 +236,10 @@ def setup(
         )
         train_cluster = cluster
         inference_cluster = cluster
-        print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+        print(
+            f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes",
+            flush=True,
+        )

     else:
         assert generation_config["backend"] != "megatron", (
@@ -289,7 +295,8 @@ def setup(
             max_colocated_worker_groups=1,
         )
         print(
-            f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node"
+            f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
+            flush=True,
         )

         # initialize inference cluster
@@ -301,13 +308,14 @@ def setup(
             max_colocated_worker_groups=1,
         )
         print(
-            f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node"
+            f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
+            flush=True,
         )

     # ==========================
     # Training and Inference
     # ==========================
-    print("\n▶ Setting up model and training...")
+    print("\n▶ Setting up model and training...", flush=True)

     # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
     backend = generation_config["backend"]
@@ -316,7 +324,8 @@ def setup(
     if backend == "megatron":
         policy_generation = None
         print(
-            f" ✓ Using {backend} backend for generation with {policy_config['model_name']}"
+            f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
+            flush=True,
         )
     elif backend == "vllm":
         generation_config = cast(VllmConfig, generation_config)
@@ -332,7 +341,8 @@ def setup(
         # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory).
         policy_generation.finish_generation()
         print(
-            f" ✓ Using vLLM backend for generation with {policy_config['model_name']}"
+            f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
+            flush=True,
         )

     if last_checkpoint_path:
@@ -355,7 +365,7 @@ def setup(
     # if it is not colocated inference, initialize collective communication for update weights
     if not colocated_inference:
         ip, port = train_cluster.get_master_address_and_port()
-        print(f"Using ip: {ip}, port: {port} for collective communication")
+        print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
         # inference cluster + head node of the train cluster
         world_size = inference_nodes * inference_gpus_per_node + 1
         # init collective
@@ -372,7 +382,7 @@ def setup(

     print("\n" + "=" * 60)
     print(" " * 18 + "SETUP COMPLETE")
-    print("=" * 60 + "\n")
+    print("=" * 60 + "\n", flush=True)

     return (
         policy,
@@ -446,7 +456,8 @@ def refit_policy_generation(
         )
         total_num_keys = sum(len(k) for k in grouped_param_keys)
         print(
-            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+            f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups",
+            flush=True,
         )
         # do update
         for keys in grouped_param_keys:
@@ -525,7 +536,7 @@ def grpo_train(

     # Run validation at the start if configured
     if val_at_start and step == 0:
-        print("\n🔍 Running initial validation...")
+        print("\n🔍 Running initial validation...", flush=True)
         if NEED_REFIT and POLICY_GENERATION_STALE:
             refit_policy_generation(policy, policy_generation, colocated_inference)
             POLICY_GENERATION_STALE = False
@@ -547,7 +558,8 @@ def grpo_train(
     batch: BatchedDataDict[DatumSpec]
     for batch in dataloader:
         print(
-            f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
+            f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}",
+            flush=True,
         )
         maybe_gpu_profile_step(policy, step + 1)
         if policy != policy_generation:
@@ -556,7 +568,7 @@ def grpo_train(

         with timer.time("total_step_time"):
             # Prepare batch
-            print("▶ Preparing batch...")
+            print("▶ Preparing batch...", flush=True)
             with timer.time("data_processing"):
                 # Repeat batch items
                 repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
@@ -570,7 +582,10 @@ def grpo_train(
                 input_ids = batched_flat["token_ids"]

             # Generate responses - this updates the LLMMessageLogType in repeated_batch
-            print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
+            print(
+                f"▶ Generating responses for batch of size {repeated_batch.size}...",
+                flush=True,
+            )
             with timer.time("prepare_for_generation"):
                 if NEED_REFIT and POLICY_GENERATION_STALE:
                     refit_policy_generation(
@@ -612,12 +627,12 @@ def grpo_train(
                 policy_generation.finish_generation()

             # Calculate rewards & advantages
-            print("▶ Processing rewards...")
+            print("▶ Processing rewards...", flush=True)
             with timer.time("reward_calculation"):
                 # Extract rewards from final_batch
                 rewards = repeated_batch["total_reward"]

-                print("▶ Computing advantages...")
+                print("▶ Computing advantages...", flush=True)
                 baseline, std = calculate_baseline_and_std_per_prompt(
                     input_ids,
                     rewards,
@@ -689,11 +704,11 @@ def grpo_train(
                 train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
                 train_data.to("cpu")

-            print("▶ Preparing for logprob inference...")
+            print("▶ Preparing for logprob inference...", flush=True)
             with timer.time("logprob_inference_prep"):
                 policy.prepare_for_lp_inference()

-            print("▶ Computing logprobs...")
+            print("▶ Computing logprobs...", flush=True)
             with timer.time("policy_and_reference_logprobs"):
                 fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
                 reference_logprobs = policy.get_reference_policy_logprobs(train_data)[
@@ -702,12 +717,12 @@ def grpo_train(
                 train_data["prev_logprobs"] = fprop_logprobs
                 train_data["reference_policy_logprobs"] = reference_logprobs

-            print("▶ Preparing for training...")
+            print("▶ Preparing for training...", flush=True)
             with timer.time("training_prep"):
                 policy.prepare_for_training()  # set model train and reload optim to GPU
                 POLICY_GENERATION_STALE = True

-            print("▶ Training policy...")
+            print("▶ Training policy...", flush=True)
             with timer.time("policy_training"):
                 train_results = policy.train(train_data, loss_fn)

@@ -774,7 +789,7 @@ def grpo_train(
                         master_config["checkpointing"]["metric_name"] = None

                 with timer.time("checkpointing"):
-                    print(f"Saving checkpoint for step {step + 1}...")
+                    print(f"Saving checkpoint for step {step + 1}...", flush=True)
                     checkpoint_path = checkpointer.init_tmp_checkpoint(
                         step + 1, grpo_save_state, master_config
                     )
@@ -845,24 +860,27 @@ def grpo_train(
         print(f" • Loss: {metrics['loss']:.4f}")
         print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
         print(
-            f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
+            f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}",
+            flush=True,
         )
         if "total_flops" in train_results:
             total_tflops = (
                 train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
             )
             num_ranks = train_results["num_ranks"]
             print(
-                f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
+                f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
+                flush=True,
             )
             if "theoretical_tflops" in train_results:
                 theoretical_tflops = train_results["theoretical_tflops"]
                 print(
-                    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
+                    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
+                    flush=True,
                 )
                 metrics["train_fp_utilization"] = total_tflops / theoretical_tflops

-        print("\n⏱️ Timing:")
+        print("\n⏱️ Timing:", flush=True)
         # Display total time first, separately
         total_time = timing_metrics.get("total_step_time", 0)

@@ -878,15 +896,15 @@ def grpo_train(
             }
         )

-        print(f" • Total step time: {total_time:.2f}s")
+        print(f" • Total step time: {total_time:.2f}s", flush=True)

         # Display all other timing metrics
         for k, v in sorted(
             timing_metrics.items(), key=lambda item: item[1], reverse=True
         ):
             if k != "total_step_time":
                 percent = (v / total_time * 100) if total_time > 0 else 0
-                print(f" • {k}: {v:.2f}s ({percent:.1f}%)")
+                print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True)

         logger.log_metrics(metrics, step + 1, prefix="train")
         logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")
@@ -907,12 +925,12 @@ def validate(
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Run validation on the validation dataset."""
     if val_dataloader is None:
-        print(" ⚠️ No validation dataloader provided, skipping validation")
+        print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
         return {}, {}

     timer = Timer()
     with timer.time("total_validation_time"):
-        print(f"▶ Starting validation at step {step}...")
+        print(f"▶ Starting validation at step {step}...", flush=True)

         total_rewards = []
         total_lengths = []
@@ -985,7 +1003,7 @@ def validate(
             )
         except Exception as e:
             print(f"\n⚠️ Error displaying message samples: {str(e)}")
-            print(" ⚠️ Continuing validation without displaying samples...")
+            print(" ⚠️ Continuing validation without displaying samples...", flush=True)

     # Get timing metrics
     timing_metrics = timer.get_timing_metrics(reduction_op="sum")
@@ -995,12 +1013,12 @@ def validate(
     print("\n📊 Validation Results:")
     print(f" • Accuracy: {accuracy:.4f}")
     print(f" • Average response length: {avg_length:.1f} tokens")
-    print(f" • Samples processed: {len(total_rewards)}")
+    print(f" • Samples processed: {len(total_rewards)}", flush=True)

     # Print timing information
     print("\n⏱️ Validation Timing:")
     validation_time = timing_metrics.get("total_validation_time", 0)
-    print(f" • Total validation time: {validation_time:.2f}s")
+    print(f" • Total validation time: {validation_time:.2f}s", flush=True)

     # Make sure to reset the timer after validation
     timer.reset()
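
The change is the same at every call site: pass `flush=True` so progress lines reach the log as soon as they are printed. When stdout is redirected to a file or captured by a job launcher (as is typical for Ray workers or batch jobs), Python block-buffers it, and unflushed prints can lag far behind the step that produced them. A minimal sketch of the effect; the `flushed_print` helper and `train_grpo.py` name are illustrative assumptions, not part of this codebase:

```python
import functools

# Hypothetical helper equivalent to adding flush=True at each print call:
# functools.partial binds the keyword argument once.
flushed_print = functools.partial(print, flush=True)

flushed_print("▶ Preparing batch...")  # written to the redirected log immediately

# Process-wide alternatives with the same effect, assuming the script is
# launched directly:
#   python -u train_grpo.py           # -u disables stdout/stderr buffering
#   PYTHONUNBUFFERED=1 python train_grpo.py
```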