 from nemo_rl.algorithms.loss_functions import (
     DPOLossFn,
 )
-from nemo_rl.algorithms.utils import set_seed
+from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
 from nemo_rl.data import DataConfig
 from nemo_rl.data.datasets import AllTaskProcessedDataset, preference_collate_fn
 from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
@@ -87,7 +87,14 @@ class MasterConfig(TypedDict):
 
 class DPOValMetrics(TypedDict):
     loss: float
+    sft_loss: float
+    preference_loss: float
     accuracy: float
+    rewards_chosen_mean: float
+    rewards_rejected_mean: float
+    num_valid_samples: float
+    global_valid_seqs: float
+    global_valid_toks: float
 
 
 # =======================================================
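
The new fields mirror the per-microbatch metrics reported by the DPO loss, and spelling them out in the `TypedDict` lets the validation loop below iterate over `DPOValMetrics.__annotations__` instead of hard-coding metric names. A minimal standalone sketch of that introspection pattern (using a trimmed-down TypedDict, not the real one):

```python
# Minimal sketch: TypedDict field names are recoverable at runtime via
# __annotations__, which is the mechanism the aggregation loop relies on.
from typing import TypedDict

class Metrics(TypedDict):
    loss: float
    accuracy: float

print(list(Metrics.__annotations__.keys()))  # ['loss', 'accuracy']
```
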
@@ -187,7 +194,7 @@ def setup(
                 ],
                 add_loss_mask=True,
             ),
-            drop_last=True,
+            drop_last=False,
         )
         for k, v in val_dataset.items()
     }
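
With `drop_last=False`, the final validation batch can be smaller than the configured batch size, so its size is no longer guaranteed to be divisible by `micro_batch_size * dp_size`; the padding added in the next hunk handles exactly this case. A quick sketch of the target size that padding has to reach, with hypothetical numbers:

```python
import math

# Hypothetical numbers: a 100-sample final batch, dp_size=4, micro_batch_size=8.
batch_size, dp_size, micro_batch_size = 100, 4, 8
multiple = dp_size * micro_batch_size         # 32
padded = math.ceil(batch_size / multiple) * multiple
print(padded)                                 # 128 -> 28 padding samples appended
```
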
@@ -255,6 +262,15 @@ def add_ref_logprobs_to_data(dataloader, policy, master_config, is_val=False):
             else master_config["policy"]["train_micro_batch_size"] * 2
         )
 
+        # when running validation with drop_last=False, we might end up with a partial batch.
+        # In this case, we pad the batch to the next multiple of micro_batch_size * dp_size.
+        dp_size = policy.sharding_annotations.get_axis_size("data_parallel")
+        if batch.size % (dp_size * micro_batch_size) != 0:
+            assert is_val, (
+                "Partial batches should only happen during validation, but got a partial batch during training."
+            )
+            batch = maybe_pad_last_batch(batch, dp_size, micro_batch_size)
+
         ## append ref policy logprobs to batch
         logprobs = policy.get_reference_policy_logprobs(
             batch,
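
For orientation, a helper with this call signature could be implemented roughly as below. This is only a sketch under assumptions (the batch behaves like a dict of equal-length tensors, and padding repeats the last row); the actual `maybe_pad_last_batch` in `nemo_rl.algorithms.utils` may differ.

```python
import torch

def maybe_pad_last_batch_sketch(batch, dp_size, micro_batch_size):
    """Hypothetical sketch: pad every tensor in `batch` along dim 0 so the
    batch size becomes a multiple of dp_size * micro_batch_size."""
    multiple = dp_size * micro_batch_size
    size = next(iter(batch.values())).shape[0]
    pad = (-size) % multiple  # samples needed to reach the next multiple
    if pad == 0:
        return batch
    # Repeat the last row `pad` times; downstream metrics must mask these out.
    return {
        k: torch.cat([v, v[-1:].expand(pad, *v.shape[1:])], dim=0)
        for k, v in batch.items()
    }
```
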
@@ -342,7 +358,7 @@ def validate_one_dataset(
     with timer.time("total_validation_time"):
         print(f"▶ Starting validation at step {step} for `{dataset_name}` set..")
 
-        val_metrics = defaultdict(lambda: 0.0)
+        val_metrics = defaultdict(list)
         num_valid_batches = 0
         for batch_idx, val_batch in enumerate(
             add_ref_logprobs_to_data(val_dataloader, policy, master_config, is_val=True)
@@ -352,7 +368,7 @@ def validate_one_dataset(
                 val_batch,
                 loss_fn,
                 eval_mode=True,
-                gbs=val_batch_size * 2,
+                gbs=val_batch.size,
                 mbs=val_mbs * 2,
             )
 
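Because the final validation batch may now be shorter or padded, the global batch size passed to `policy.train` is read from the batch itself (`val_batch.size`) rather than derived from the configured `val_batch_size`, which would be wrong for that last batch.
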
@@ -361,22 +377,61 @@ def validate_one_dataset(
361377 "No validation metrics were collected for this batch."
362378 " This is likely because there were no valid samples."
363379 )
364-
365380 else :
366- for k , v in val_results ["all_mb_metrics" ].items ():
367- if k in {"lr" , "wd" , "global_valid_seqs" , "global_valid_toks" }:
368- val_metrics [k ] += np .mean (v ).item ()
369- else :
370- val_metrics [k ] += np .sum (v ).item ()
381+ for metric_name in DPOValMetrics .__annotations__ .keys ():
382+ reduction = (
383+ np .mean
384+ if metric_name in {"global_valid_seqs" , "global_valid_toks" }
385+ else sum
386+ )
387+ val_metrics [metric_name ] += [
388+ reduction (val_results ["all_mb_metrics" ][metric_name ])
389+ ]
390+
371391 num_valid_batches += 1
 
             if val_batches > 0 and batch_idx >= val_batches - 1:
                 break
 
-        for k, v in val_metrics.items():
-            if k == "num_valid_samples":
-                continue
-            val_metrics[k] /= num_valid_batches
+        if num_valid_batches > 0:
+            sum_num_valid_samples = sum(val_metrics["num_valid_samples"])
+            global_valid_toks = sum(val_metrics["global_valid_toks"])
+            global_valid_seqs = sum(val_metrics["global_valid_seqs"])
+            val_metrics = DPOValMetrics(
+                num_valid_samples=sum_num_valid_samples,
+                global_valid_seqs=global_valid_seqs,
+                global_valid_toks=global_valid_toks,
+                **{
+                    metric_name: sum(
+                        [
+                            value * weight
+                            for value, weight in zip(
+                                val_metrics[metric_name],
+                                val_metrics["num_valid_samples"],
+                            )
+                        ]
+                    )
+                    / sum_num_valid_samples
+                    for metric_name in DPOValMetrics.__annotations__.keys()
+                    if metric_name
+                    not in {
+                        "num_valid_samples",
+                        "global_valid_seqs",
+                        "global_valid_toks",
+                    }
+                },
+            )
+        else:
+            warnings.warn(
+                "No validation metrics were collected."
+                " This is likely because there were no valid samples in the validation set."
+            )
+            val_metrics = DPOValMetrics(
+                **{
+                    metric_name: 0.0
+                    for metric_name in DPOValMetrics.__annotations__.keys()
+                }
+            )
 
         # Calculate validation metrics
         policy.prepare_for_training()
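
The old code averaged the accumulated metrics uniformly over batches; the new aggregation computes a sample-weighted mean (weighted by each batch's `num_valid_samples`), so a small padded final batch no longer skews the result, while the count-like fields (`num_valid_samples`, `global_valid_seqs`, `global_valid_toks`) are simply summed. A toy illustration of the weighted mean with hypothetical numbers:

```python
# Two validation batches with different numbers of valid samples (hypothetical).
losses = [0.50, 0.80]           # per-batch loss values
num_valid_samples = [128, 32]   # per-batch valid-sample counts

weighted = sum(
    value * weight for value, weight in zip(losses, num_valid_samples)
) / sum(num_valid_samples)
print(weighted)  # ~0.56, vs. 0.65 for an unweighted mean over batches
```
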