
Commit 7c5efd8

chore: flush to stdout when print logging during GRPO (#1021)
Signed-off-by: Peter Jin <pjin@nvidia.com>
1 parent 0358a86 commit 7c5efd8
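
Background for the change: when a Python process writes to a redirected stdout (a file or pipe, as is typical for Ray worker logs and batch-job logs), the stream is block-buffered rather than line-buffered, so plain print() output can lag far behind the actual progress of training. Passing flush=True makes each message visible immediately. A minimal sketch of the difference, with illustrative messages only (not taken from grpo.py):

import sys

# With stdout redirected to a file, output is block-buffered by default.
print("step started")              # may sit in the buffer until it fills
print("step started", flush=True)  # forced out to the log right away
sys.stdout.flush()                 # explicit flush of anything still buffered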

File tree

1 file changed (+52 / -34 lines changed)


nemo_rl/algorithms/grpo.py

Lines changed: 52 additions & 34 deletions
@@ -197,7 +197,7 @@ def setup(
 )
 dataloader.load_state_dict(dataloader_state_dict)
 
-print(f" ✓ Training dataloader loaded with {len(dataset)} samples")
+print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)
 
 # Load validation dataset if provided
 val_dataloader: Optional[StatefulDataLoader] = None
@@ -212,12 +212,15 @@ def setup(
 shuffle=False,
 collate_fn=rl_collate_fn,
 )
-print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples")
+print(
+    f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
+    flush=True,
+)
 
 # ==========================
 # Cluster
 # ==========================
-print("\n▶ Setting up compute cluster...")
+print("\n▶ Setting up compute cluster...", flush=True)
 colocated_inference = generation_config["colocated"]["enabled"]
 
 if colocated_inference:
@@ -233,7 +236,10 @@ def setup(
 )
 train_cluster = cluster
 inference_cluster = cluster
-print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+print(
+    f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes",
+    flush=True,
+)
 
 else:
 assert generation_config["backend"] != "megatron", (
@@ -289,7 +295,8 @@ def setup(
 max_colocated_worker_groups=1,
 )
 print(
-    f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node"
+    f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
+    flush=True,
 )
 
 # initialize inference cluster
@@ -301,13 +308,14 @@ def setup(
 max_colocated_worker_groups=1,
 )
 print(
-    f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node"
+    f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
+    flush=True,
 )
 
 # ==========================
 # Training and Inference
 # ==========================
-print("\n▶ Setting up model and training...")
+print("\n▶ Setting up model and training...", flush=True)
 
 # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
 backend = generation_config["backend"]
@@ -316,7 +324,8 @@ def setup(
 if backend == "megatron":
 policy_generation = None
 print(
-    f" ✓ Using {backend} backend for generation with {policy_config['model_name']}"
+    f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
+    flush=True,
 )
 elif backend == "vllm":
 generation_config = cast(VllmConfig, generation_config)
@@ -332,7 +341,8 @@ def setup(
 # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory).
 policy_generation.finish_generation()
 print(
-    f" ✓ Using vLLM backend for generation with {policy_config['model_name']}"
+    f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
+    flush=True,
 )
 
 if last_checkpoint_path:
@@ -355,7 +365,7 @@ def setup(
 # if it is not colocated inference, initialize collective communication for update weights
 if not colocated_inference:
 ip, port = train_cluster.get_master_address_and_port()
-print(f"Using ip: {ip}, port: {port} for collective communication")
+print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
 # inference cluster + head node of the train cluster
 world_size = inference_nodes * inference_gpus_per_node + 1
 # init collective
@@ -372,7 +382,7 @@ def setup(
 
 print("\n" + "=" * 60)
 print(" " * 18 + "SETUP COMPLETE")
-print("=" * 60 + "\n")
+print("=" * 60 + "\n", flush=True)
 
 return (
 policy,
@@ -446,7 +456,8 @@ def refit_policy_generation(
 )
 total_num_keys = sum(len(k) for k in grouped_param_keys)
 print(
-    f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+    f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups",
+    flush=True,
 )
 # do update
 for keys in grouped_param_keys:
@@ -525,7 +536,7 @@ def grpo_train(
 
 # Run validation at the start if configured
 if val_at_start and step == 0:
-print("\n🔍 Running initial validation...")
+print("\n🔍 Running initial validation...", flush=True)
 if NEED_REFIT and POLICY_GENERATION_STALE:
 refit_policy_generation(policy, policy_generation, colocated_inference)
 POLICY_GENERATION_STALE = False
@@ -547,7 +558,8 @@ def grpo_train(
 batch: BatchedDataDict[DatumSpec]
 for batch in dataloader:
 print(
-    f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
+    f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}",
+    flush=True,
 )
 maybe_gpu_profile_step(policy, step + 1)
 if policy != policy_generation:
@@ -556,7 +568,7 @@ def grpo_train(
 
 with timer.time("total_step_time"):
 # Prepare batch
-print("▶ Preparing batch...")
+print("▶ Preparing batch...", flush=True)
 with timer.time("data_processing"):
 # Repeat batch items
 repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
@@ -570,7 +582,10 @@ def grpo_train(
 input_ids = batched_flat["token_ids"]
 
 # Generate responses - this updates the LLMMessageLogType in repeated_batch
-print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
+print(
+    f"▶ Generating responses for batch of size {repeated_batch.size}...",
+    flush=True,
+)
 with timer.time("prepare_for_generation"):
 if NEED_REFIT and POLICY_GENERATION_STALE:
 refit_policy_generation(
@@ -612,12 +627,12 @@ def grpo_train(
 policy_generation.finish_generation()
 
 # Calculate rewards & advantages
-print("▶ Processing rewards...")
+print("▶ Processing rewards...", flush=True)
 with timer.time("reward_calculation"):
 # Extract rewards from final_batch
 rewards = repeated_batch["total_reward"]
 
-print("▶ Computing advantages...")
+print("▶ Computing advantages...", flush=True)
 baseline, std = calculate_baseline_and_std_per_prompt(
 input_ids,
 rewards,
@@ -689,11 +704,11 @@ def grpo_train(
 train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
 train_data.to("cpu")
 
-print("▶ Preparing for logprob inference...")
+print("▶ Preparing for logprob inference...", flush=True)
 with timer.time("logprob_inference_prep"):
 policy.prepare_for_lp_inference()
 
-print("▶ Computing logprobs...")
+print("▶ Computing logprobs...", flush=True)
 with timer.time("policy_and_reference_logprobs"):
 fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
 reference_logprobs = policy.get_reference_policy_logprobs(train_data)[
@@ -702,12 +717,12 @@ def grpo_train(
 train_data["prev_logprobs"] = fprop_logprobs
 train_data["reference_policy_logprobs"] = reference_logprobs
 
-print("▶ Preparing for training...")
+print("▶ Preparing for training...", flush=True)
 with timer.time("training_prep"):
 policy.prepare_for_training() # set model train and reload optim to GPU
 POLICY_GENERATION_STALE = True
 
-print("▶ Training policy...")
+print("▶ Training policy...", flush=True)
 with timer.time("policy_training"):
 train_results = policy.train(train_data, loss_fn)
 
@@ -774,7 +789,7 @@ def grpo_train(
 master_config["checkpointing"]["metric_name"] = None
 
 with timer.time("checkpointing"):
-print(f"Saving checkpoint for step {step + 1}...")
+print(f"Saving checkpoint for step {step + 1}...", flush=True)
 checkpoint_path = checkpointer.init_tmp_checkpoint(
 step + 1, grpo_save_state, master_config
 )
@@ -845,24 +860,27 @@ def grpo_train(
 print(f" • Loss: {metrics['loss']:.4f}")
 print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
 print(
-    f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
+    f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}",
+    flush=True,
 )
 if "total_flops" in train_results:
 total_tflops = (
 train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
 )
 num_ranks = train_results["num_ranks"]
 print(
-    f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
+    f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
+    flush=True,
 )
 if "theoretical_tflops" in train_results:
 theoretical_tflops = train_results["theoretical_tflops"]
 print(
-    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
+    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
+    flush=True,
 )
 metrics["train_fp_utilization"] = total_tflops / theoretical_tflops
 
-print("\n⏱️ Timing:")
+print("\n⏱️ Timing:", flush=True)
 # Display total time first, separately
 total_time = timing_metrics.get("total_step_time", 0)
 
@@ -878,15 +896,15 @@ def grpo_train(
 }
 )
 
-print(f" • Total step time: {total_time:.2f}s")
+print(f" • Total step time: {total_time:.2f}s", flush=True)
 
 # Display all other timing metrics
 for k, v in sorted(
 timing_metrics.items(), key=lambda item: item[1], reverse=True
 ):
 if k != "total_step_time":
 percent = (v / total_time * 100) if total_time > 0 else 0
-print(f" • {k}: {v:.2f}s ({percent:.1f}%)")
+print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True)
 
 logger.log_metrics(metrics, step + 1, prefix="train")
 logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")
@@ -907,12 +925,12 @@ def validate(
 ) -> tuple[dict[str, Any], dict[str, Any]]:
 """Run validation on the validation dataset."""
 if val_dataloader is None:
-print(" ⚠️ No validation dataloader provided, skipping validation")
+print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
 return {}, {}
 
 timer = Timer()
 with timer.time("total_validation_time"):
-print(f"▶ Starting validation at step {step}...")
+print(f"▶ Starting validation at step {step}...", flush=True)
 
 total_rewards = []
 total_lengths = []
@@ -985,7 +1003,7 @@ def validate(
 )
 except Exception as e:
 print(f"\n ⚠️ Error displaying message samples: {str(e)}")
-print(" ⚠️ Continuing validation without displaying samples...")
+print(" ⚠️ Continuing validation without displaying samples...", flush=True)
 
 # Get timing metrics
 timing_metrics = timer.get_timing_metrics(reduction_op="sum")
@@ -995,12 +1013,12 @@ def validate(
 print("\n📊 Validation Results:")
 print(f" • Accuracy: {accuracy:.4f}")
 print(f" • Average response length: {avg_length:.1f} tokens")
-print(f" • Samples processed: {len(total_rewards)}")
+print(f" • Samples processed: {len(total_rewards)}", flush=True)
 
 # Print timing information
 print("\n ⏱️ Validation Timing:")
 validation_time = timing_metrics.get("total_validation_time", 0)
-print(f" • Total validation time: {validation_time:.2f}s")
+print(f" • Total validation time: {validation_time:.2f}s", flush=True)
 
 # Make sure to reset the timer after validation
 timer.reset()

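The commit opts in per call site by adding flush=True to each print. For comparison only, the same effect can be obtained process-wide; a hedged sketch of common alternatives (none of these are what this commit does):

import functools
import sys

# Alternative 1: run the interpreter unbuffered (PYTHONUNBUFFERED=1 or `python -u`).
# Alternative 2: switch stdout to line buffering at startup (Python 3.7+).
sys.stdout.reconfigure(line_buffering=True)

# Alternative 3: a small always-flushing wrapper used in place of print.
print_flushed = functools.partial(print, flush=True)
print_flushed("▶ Training policy...")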