
Commit c48508e

fix seq packing

Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent: bb7ae5e

File tree: 2 files changed (+51, -25 lines)

- nemo_rl/algorithms/loss_functions.py
- nemo_rl/models/automodel/train.py


nemo_rl/algorithms/loss_functions.py (15 additions, 10 deletions)
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Any, NotRequired, Optional, TypedDict, TypeVar
+from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar
 
 import torch
 import torch.distributed
@@ -802,22 +802,27 @@ class SequencePackingLossWrapper:
     def __init__(
         self,
         loss_fn: LossFunction,
+        prepare_fn: Callable[..., Any],
         cu_seqlens_q: Tensor,
         cu_seqlens_q_padded: Optional[Tensor] = None,
+        vocab_parallel_rank: Optional[int] = None,
+        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         self.loss_fn = loss_fn
+        self.prepare_fn = prepare_fn
         self.cu_seqlens_q = cu_seqlens_q
         self.cu_seqlens_q_padded = cu_seqlens_q_padded
+        self.vocab_parallel_rank = vocab_parallel_rank
+        self.vocab_parallel_group = vocab_parallel_group
+        self.context_parallel_group = context_parallel_group
 
     def __call__(
         self,
         next_token_logits: Tensor,
         data: BatchedDataDict[Any],
         global_valid_seqs: Tensor | None,
         global_valid_toks: Tensor | None,
-        vocab_parallel_rank: Optional[int] = None,
-        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
-        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
     ) -> tuple[Tensor, dict[str, Any]]:
         """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
         unpadded_cu_seqlens = self.cu_seqlens_q
@@ -851,8 +856,8 @@ def __call__(
             # get next_token_logits
             cp_size = (
                 1
-                if context_parallel_group is None
-                else torch.distributed.get_world_size(context_parallel_group)
+                if self.context_parallel_group is None
+                else torch.distributed.get_world_size(self.context_parallel_group)
             )
             logit_start = seq_start // cp_size
             logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
@@ -861,14 +866,14 @@ def __call__(
                 1, logit_start, logit_length
             )
 
+            # prepare data for loss function
+            loss_fn_args = self.prepare_fn(next_token_logits_slice, unpadded_seq_data)
+
             loss, metrics = self.loss_fn(
-                next_token_logits_slice,
+                *loss_fn_args,
                 unpadded_seq_data,
                 global_valid_seqs,
                 global_valid_toks,
-                vocab_parallel_rank=vocab_parallel_rank,
-                vocab_parallel_group=vocab_parallel_group,
-                context_parallel_group=context_parallel_group,
             )
             loss_accum += loss
             for k, v in metrics.items():
```
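
Taken together, the wrapper's contract changes in two ways: the vocab/context parallel handles are now bound once in `__init__`, and `prepare_fn` maps each per-sequence logits slice to the positional arguments of `loss_fn`. Below is a minimal sketch of that contract; the `toy_*` prepare and loss functions are hypothetical stand-ins, and only `SequencePackingLossWrapper` itself comes from this commit.

```python
import torch
from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper

def toy_prepare_fn(next_token_logits_slice, seq_data):
    # Map the logits slice for one unpacked sequence to the positional
    # args of the loss function (here: per-token logprobs of the labels).
    logprobs = torch.log_softmax(next_token_logits_slice.float(), dim=-1)
    token_logprobs = logprobs.gather(
        -1, seq_data["input_ids"].unsqueeze(-1)
    ).squeeze(-1)
    return (token_logprobs,)  # splatted as loss_fn(*args, seq_data, ...)

def toy_loss_fn(token_logprobs, seq_data, global_valid_seqs, global_valid_toks):
    # Per-sequence NLL. Note there are no parallelism kwargs here anymore;
    # the wrapper holds the vocab/context parallel handles itself.
    loss = -token_logprobs.mean()
    return loss, {"nll": loss.detach()}

wrapper = SequencePackingLossWrapper(
    loss_fn=toy_loss_fn,
    prepare_fn=toy_prepare_fn,
    cu_seqlens_q=torch.tensor([0, 5, 12]),         # two sequences packed into 12 tokens
    cu_seqlens_q_padded=torch.tensor([0, 8, 16]),  # padded per-sequence boundaries
    context_parallel_group=None,                   # cp_size falls back to 1
)
```

The per-sequence loop in `__call__` then runs `loss_fn(*prepare_fn(logits_slice, seq_data), seq_data, global_valid_seqs, global_valid_toks)`, so the three parallelism kwargs no longer have to be threaded through every loss call.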

nemo_rl/models/automodel/train.py (36 additions, 15 deletions)
```diff
@@ -505,6 +505,12 @@ def __call__(
         Returns:
             Tuple of (loss, metrics)
         """
+        from nemo_rl.algorithms.loss_functions import (
+            ClippedPGLossFn,
+            DPOLossFn,
+            NLLLoss,
+        )
+
         # Handle CP redistribution
         if self.cp_size > 1:
             _, mb = prepare_data_for_cp(
@@ -514,30 +520,45 @@ def __call__(
                 logits, self.device_mesh, self.cp_mesh, sequence_dim
             )
 
-        # Compute logprobs from logits
-        logprobs = get_logprobs_from_logits(
-            input_ids=mb["input_ids"],
-            next_token_logits=logits,
-            seq_index=mb.get("seq_index", None),
-        )
-        del logits
+        # Prepare data for loss function
+        def prepare_for_loss_fn(
+            logits: torch.Tensor, mb: BatchedDataDict[Any]
+        ) -> tuple[Any]:
+            if isinstance(self.loss_fn, (ClippedPGLossFn, NLLLoss, DPOLossFn)):
+                logprobs = get_logprobs_from_logits(
+                    input_ids=mb["input_ids"],
+                    next_token_logits=logits,
+                    seq_index=mb.get("seq_index", None),
+                )
+
+                loss_fn_args = (logprobs,)
+
+            # TODO: PreferenceLoss, DistillationLossFn
+
+            return loss_fn_args
 
         # Wrap loss function for sequence packing if needed
         if self.enable_seq_packing:
             loss_fn_ = SequencePackingLossWrapper(
                 loss_fn=self.loss_fn,
+                prepare_fn=prepare_for_loss_fn,
                 cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q,
                 cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q,
             )
+            loss, loss_metrics = loss_fn_(
+                logits,
+                mb,
+                global_valid_seqs,
+                global_valid_toks,
+            )
         else:
-            loss_fn_ = self.loss_fn
-
-        loss, loss_metrics = loss_fn_(
-            logprobs,
-            mb,
-            global_valid_seqs,
-            global_valid_toks,
-        )
+            loss_fn_args = prepare_for_loss_fn(logits, mb)
+            loss, loss_metrics = self.loss_fn(
+                *loss_fn_args,
+                mb,
+                global_valid_seqs,
+                global_valid_toks,
+            )
 
         return loss, loss_metrics
```
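
The unpacked path uses the same split: build the loss-function arguments from the logits first, then call the loss. Here is a self-contained sketch of that dispatch pattern; the stub loss class and the simplified `get_logprobs_from_logits` are hypothetical stand-ins for the real NeMo RL ones, and only the control flow mirrors the commit.

```python
from typing import Any
import torch

class LogprobLoss:
    # Stand-in for the logprob-consuming losses the commit names
    # (ClippedPGLossFn, NLLLoss, DPOLossFn).
    def __call__(self, logprobs, mb, global_valid_seqs, global_valid_toks):
        loss = -logprobs.mean()
        return loss, {"nll": loss.item()}

def get_logprobs_from_logits(input_ids, next_token_logits, seq_index=None):
    # Simplified stand-in (ignores the next-token shift of the real helper):
    # logprob of each label token under the logits.
    logprobs = torch.log_softmax(next_token_logits.float(), dim=-1)
    return logprobs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

loss_fn = LogprobLoss()

def prepare_for_loss_fn(logits: torch.Tensor, mb: dict[str, Any]) -> tuple[Any, ...]:
    # Logprob-based losses consume logprobs; other loss families would
    # return different argument tuples (hence the TODO in the commit).
    if isinstance(loss_fn, LogprobLoss):
        return (get_logprobs_from_logits(mb["input_ids"], logits),)
    raise NotImplementedError(type(loss_fn))

logits = torch.randn(1, 6, 32)
mb = {"input_ids": torch.randint(0, 32, (1, 6))}
loss, metrics = loss_fn(*prepare_for_loss_fn(logits, mb), mb, None, None)
print(loss, metrics)
```

One consequence visible in the diff: the old `del logits` is gone, which appears to be the price of letting the packed path slice the full logits tensor per sequence before logprobs are computed.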
