Skip to content

Commit dc865ef

Browse files
authored
feat: support drop_last=False during validation (#1029)
Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: Anna Shors <ashors@nvidia.com>
1 parent 7675ae4 commit dc865ef

File tree

6 files changed

+260
-31
lines changed

6 files changed

+260
-31
lines changed

nemo_rl/algorithms/dpo.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from nemo_rl.algorithms.loss_functions import (
2727
DPOLossFn,
2828
)
29-
from nemo_rl.algorithms.utils import set_seed
29+
from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
3030
from nemo_rl.data import DataConfig
3131
from nemo_rl.data.datasets import AllTaskProcessedDataset, preference_collate_fn
3232
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
@@ -87,7 +87,14 @@ class MasterConfig(TypedDict):
8787

8888
class DPOValMetrics(TypedDict):
8989
loss: float
90+
sft_loss: float
91+
preference_loss: float
9092
accuracy: float
93+
rewards_chosen_mean: float
94+
rewards_rejected_mean: float
95+
num_valid_samples: float
96+
global_valid_seqs: float
97+
global_valid_toks: float
9198

9299

93100
# =======================================================
@@ -187,7 +194,7 @@ def setup(
187194
],
188195
add_loss_mask=True,
189196
),
190-
drop_last=True,
197+
drop_last=False,
191198
)
192199
for k, v in val_dataset.items()
193200
}
@@ -255,6 +262,15 @@ def add_ref_logprobs_to_data(dataloader, policy, master_config, is_val=False):
255262
else master_config["policy"]["train_micro_batch_size"] * 2
256263
)
257264

265+
# when running validation with drop_last=False, we might end up with a partial batch.
266+
# In this case, we pad the batch to the next multiple of micro_batch_size * dp_size.
267+
dp_size = policy.sharding_annotations.get_axis_size("data_parallel")
268+
if batch.size % (dp_size * micro_batch_size) != 0:
269+
assert is_val, (
270+
"Partial batches should only happen during validation, but got a partial batch during training."
271+
)
272+
batch = maybe_pad_last_batch(batch, dp_size, micro_batch_size)
273+
258274
## append ref policy logprobs to batch
259275
logprobs = policy.get_reference_policy_logprobs(
260276
batch,
@@ -342,7 +358,7 @@ def validate_one_dataset(
342358
with timer.time("total_validation_time"):
343359
print(f"▶ Starting validation at step {step} for `{dataset_name}` set..")
344360

345-
val_metrics = defaultdict(lambda: 0.0)
361+
val_metrics = defaultdict(list)
346362
num_valid_batches = 0
347363
for batch_idx, val_batch in enumerate(
348364
add_ref_logprobs_to_data(val_dataloader, policy, master_config, is_val=True)
@@ -352,7 +368,7 @@ def validate_one_dataset(
352368
val_batch,
353369
loss_fn,
354370
eval_mode=True,
355-
gbs=val_batch_size * 2,
371+
gbs=val_batch.size,
356372
mbs=val_mbs * 2,
357373
)
358374

@@ -361,22 +377,61 @@ def validate_one_dataset(
361377
"No validation metrics were collected for this batch."
362378
" This is likely because there were no valid samples."
363379
)
364-
365380
else:
366-
for k, v in val_results["all_mb_metrics"].items():
367-
if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}:
368-
val_metrics[k] += np.mean(v).item()
369-
else:
370-
val_metrics[k] += np.sum(v).item()
381+
for metric_name in DPOValMetrics.__annotations__.keys():
382+
reduction = (
383+
np.mean
384+
if metric_name in {"global_valid_seqs", "global_valid_toks"}
385+
else sum
386+
)
387+
val_metrics[metric_name] += [
388+
reduction(val_results["all_mb_metrics"][metric_name])
389+
]
390+
371391
num_valid_batches += 1
372392

373393
if val_batches > 0 and batch_idx >= val_batches - 1:
374394
break
375395

376-
for k, v in val_metrics.items():
377-
if k == "num_valid_samples":
378-
continue
379-
val_metrics[k] /= num_valid_batches
396+
if num_valid_batches > 0:
397+
sum_num_valid_samples = sum(val_metrics["num_valid_samples"])
398+
global_valid_toks = sum(val_metrics["global_valid_toks"])
399+
global_valid_seqs = sum(val_metrics["global_valid_seqs"])
400+
val_metrics = DPOValMetrics(
401+
num_valid_samples=sum_num_valid_samples,
402+
global_valid_seqs=global_valid_seqs,
403+
global_valid_toks=global_valid_toks,
404+
**{
405+
metric_name: sum(
406+
[
407+
value * weight
408+
for value, weight in zip(
409+
val_metrics[metric_name],
410+
val_metrics["num_valid_samples"],
411+
)
412+
]
413+
)
414+
/ sum_num_valid_samples
415+
for metric_name in DPOValMetrics.__annotations__.keys()
416+
if metric_name
417+
not in {
418+
"num_valid_samples",
419+
"global_valid_seqs",
420+
"global_valid_toks",
421+
}
422+
},
423+
)
424+
else:
425+
warnings.warn(
426+
"No validation metrics were collected."
427+
" This is likely because there were no valid samples in the validation set."
428+
)
429+
val_metrics = DPOValMetrics(
430+
**{
431+
metric_name: 0.0
432+
for metric_name in DPOValMetrics.__annotations__.keys()
433+
}
434+
)
380435

381436
# Calculate validation metrics
382437
policy.prepare_for_training()

nemo_rl/algorithms/rm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from nemo_rl.algorithms.loss_functions import (
2727
PreferenceLoss,
2828
)
29-
from nemo_rl.algorithms.utils import set_seed
29+
from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
3030
from nemo_rl.data import DataConfig
3131
from nemo_rl.data.datasets import (
3232
AllTaskProcessedDataset,
@@ -172,7 +172,7 @@ def setup(
172172
],
173173
add_loss_mask=False,
174174
),
175-
drop_last=True,
175+
drop_last=False,
176176
)
177177
for k, v in val_dataset.items()
178178
}
@@ -307,14 +307,20 @@ def validate_one_dataset(
307307
dict_val_metrics = defaultdict(list)
308308
num_valid_batches = 0
309309
for batch_idx, val_batch in enumerate(val_dataloader):
310+
# When running validation with drop_last=False, we might end up with a partial batch.
311+
# In this case, we pad the batch to the next multiple of micro_batch_size * dp_size.
312+
if val_batch.size < val_batch_size * 2:
313+
dp_size = policy.sharding_annotations.get_axis_size("data_parallel")
314+
val_batch = maybe_pad_last_batch(val_batch, dp_size, val_mbs * 2)
315+
310316
## just run model fwd
311317
val_results = policy.train(
312318
val_batch,
313319
loss_fn,
314320
eval_mode=True,
315-
## NOTE: we double the batch size here because each preference example corresponds to a pair of
316-
## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
317-
gbs=val_batch_size * 2,
321+
gbs=val_batch.size,
322+
# NOTE: we double the batch size because each preference example corresponds to a pair of
323+
# examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch.
318324
mbs=val_mbs * 2,
319325
)
320326

nemo_rl/algorithms/sft.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from nemo_rl.algorithms.loss_functions import (
2525
NLLLoss,
2626
)
27-
from nemo_rl.algorithms.utils import set_seed
27+
from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
2828
from nemo_rl.data import DataConfig
2929
from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
3030
from nemo_rl.data.interfaces import TaskDataSpec
@@ -150,7 +150,7 @@ def setup(
150150
batch_size=sft_config["val_global_batch_size"],
151151
shuffle=False,
152152
collate_fn=rl_collate_fn,
153-
drop_last=True,
153+
drop_last=False,
154154
)
155155

156156
# ==========================
@@ -240,7 +240,7 @@ def validate(
240240
# val_total = len(val_dataloader)
241241

242242
val_metrics = {"val_loss": 0.0}
243-
num_valid_batches = 0
243+
sum_num_valid_tokens = 0
244244

245245
policy.prepare_for_training()
246246
for batch_idx, val_batch in enumerate(val_dataloader):
@@ -269,13 +269,18 @@ def validate(
269269

270270
# update multimodal data
271271
val_data.update(cat_and_padded.get_multimodal_dict(as_tensors=False))
272+
# When running validation with drop_last=False, we might end up with a partial batch.
273+
# Check if we need to pad the final batch to make it divisible by micro_batch_size * dp_size.
274+
if val_data.size < val_batch_size:
275+
dp_size = policy.sharding_annotations.get_axis_size("data_parallel")
276+
val_data = maybe_pad_last_batch(val_data, dp_size, val_mbs)
272277

273278
## just run model fwd
274279
val_results = policy.train(
275280
val_data,
276281
loss_fn,
277282
eval_mode=True,
278-
gbs=val_batch_size,
283+
gbs=val_data.size,
279284
mbs=val_mbs,
280285
)
281286

@@ -285,14 +290,17 @@ def validate(
285290
" This is likely because there were no valid samples."
286291
)
287292
else:
288-
val_metrics["val_loss"] += float(val_results["loss"])
289-
num_valid_batches += 1
293+
num_valid_tokens = (
294+
val_data["sample_mask"].unsqueeze(-1) * val_data["token_mask"]
295+
).sum()
296+
val_metrics["val_loss"] += float(val_results["loss"]) * num_valid_tokens
297+
sum_num_valid_tokens += num_valid_tokens
290298

291299
if val_batches > 0 and batch_idx >= val_batches - 1:
292300
break
293301

294-
if num_valid_batches > 0:
295-
val_metrics["val_loss"] /= num_valid_batches
302+
if sum_num_valid_tokens > 0:
303+
val_metrics["val_loss"] /= sum_num_valid_tokens
296304
else:
297305
warnings.warn(
298306
"No validation metrics were collected."
@@ -306,7 +314,7 @@ def validate(
306314
timing_metrics = timer.get_timing_metrics(reduction_op="sum")
307315
validation_time = timing_metrics.get("total_validation_time", 0)
308316

309-
if num_valid_batches > 0:
317+
if sum_num_valid_tokens > 0:
310318
# Print summary of validation results
311319
print("\n📊 Validation Results:")
312320
print(f" • Validation loss: {val_metrics['val_loss']:.4f}")

nemo_rl/algorithms/utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
import math
1416
import random
1517
import warnings
1618
from functools import wraps
@@ -265,3 +267,62 @@ def get_tokenizer(
265267
processor.name_or_path = tokenizer.name_or_path
266268

267269
return tokenizer if processor is None else processor
270+
271+
272+
def maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) -> dict:
273+
"""Pads the given batch so that its size is divisible by (mbs * dp_size).
274+
275+
Args:
276+
batch (dict): The batch to pad.
277+
dp_size (int): Data parallel size.
278+
mbs (int): Micro batch size.
279+
280+
Returns:
281+
dict: The padded batch.
282+
"""
283+
min_padding = (math.ceil(batch.size / (mbs * dp_size)) * mbs * dp_size) - batch.size
284+
if min_padding > 0:
285+
print(f"Padding last validation batch with {min_padding} padding samples")
286+
# Pad input_ids
287+
batch["input_ids"] = torch.cat(
288+
[
289+
batch["input_ids"],
290+
batch["input_ids"][-1].unsqueeze(0).repeat(min_padding, 1),
291+
]
292+
)
293+
# Pad input_lengths
294+
batch["input_lengths"] = torch.cat(
295+
[
296+
batch["input_lengths"],
297+
batch["input_lengths"][-1].unsqueeze(0).repeat(min_padding),
298+
]
299+
)
300+
if "token_mask" in batch:
301+
# Pad token_mask
302+
batch["token_mask"] = torch.cat(
303+
[
304+
batch["token_mask"],
305+
batch["token_mask"][-1].unsqueeze(0).repeat(min_padding, 1),
306+
]
307+
)
308+
# Pad sample_mask
309+
batch["sample_mask"] = torch.cat(
310+
[
311+
batch["sample_mask"],
312+
torch.zeros_like(batch["sample_mask"][-1])
313+
.unsqueeze(0)
314+
.repeat(min_padding),
315+
]
316+
)
317+
318+
if "reference_policy_logprobs" in batch:
319+
# Pad reference_policy_logprobs
320+
batch["reference_policy_logprobs"] = torch.cat(
321+
[
322+
batch["reference_policy_logprobs"],
323+
batch["reference_policy_logprobs"][-1]
324+
.unsqueeze(0)
325+
.repeat(min_padding, 1),
326+
]
327+
)
328+
return batch

tests/unit/algorithms/test_dpo.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,31 @@
1414

1515
from unittest.mock import MagicMock
1616

17+
import numpy as np
1718
import torch
1819

1920
from nemo_rl.algorithms.dpo import add_ref_logprobs_to_data
21+
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
22+
from nemo_rl.distributed.named_sharding import NamedSharding
2023

2124

2225
class MockPolicy:
2326
def __init__(self, logprobs):
2427
self.logprobs = logprobs
28+
self.sharding_annotations = NamedSharding(
29+
layout=np.arange(2).reshape(
30+
1, # PP
31+
-1, # DP
32+
1, # CP
33+
1, # TP
34+
),
35+
names=[
36+
"pipeline_parallel",
37+
"data_parallel",
38+
"context_parallel",
39+
"tensor_parallel",
40+
],
41+
)
2542

2643
def get_reference_policy_logprobs(self, batch, micro_batch_size):
2744
return {"reference_logprobs": self.logprobs}
@@ -30,7 +47,7 @@ def get_reference_policy_logprobs(self, batch, micro_batch_size):
3047
def test_add_logprobs_to_batch():
3148
"""Test that add_ref_logprobs_to_data correctly adds reference policy logprobs to batches."""
3249
# Create mock data
33-
batch_size = 2
50+
batch_size = 8
3451
seq_len = 4
3552
vocab_size = 16
3653

@@ -45,7 +62,7 @@ def test_add_logprobs_to_batch():
4562

4663
# Create a mock dataloader that yields our mock batch
4764
mock_dataloader = MagicMock()
48-
mock_dataloader.__iter__.return_value = iter([mock_batch])
65+
mock_dataloader.__iter__.return_value = iter([BatchedDataDict(mock_batch)])
4966

5067
# Create a mock policy that returns our mock logprobs
5168
mock_policy = MockPolicy(mock_logprobs)

0 commit comments

Comments (0)