Commit 0fef58c

chtruong814 and zpqiu authored
cp: fix: Fix DTensor slice crash after PyTorch 2.9 bump (1689) into r0.5.0 (#1707)
Signed-off-by: Zhaopeng Qiu <[email protected]>
Signed-off-by: NeMo Bot <[email protected]>
Co-authored-by: alexchiu <[email protected]>
1 parent e883ac4 commit 0fef58c

1 file changed, +5 -4 lines changed

nemo_rl/algorithms/loss_functions.py

Lines changed: 5 additions & 4 deletions
@@ -922,11 +922,12 @@ def __call__(
         if context_parallel_group is None
         else torch.distributed.get_world_size(context_parallel_group)
     )
-    logit_slice_idxs = slice(
-        seq_start // cp_size,
-        (seq_start + padded_seq_lengths[seq_idx]) // cp_size,
+    logit_start = seq_start // cp_size
+    logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
+    logit_length = logit_end - logit_start
+    next_token_logits_slice = next_token_logits.narrow(
+        1, logit_start, logit_length
     )
-    next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :]
 
     loss, metrics = self.loss_fn(
         next_token_logits_slice,
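
Note on the index arithmetic: the change replaces a `slice(start, end)` index with `narrow(1, start, length)`, so the end offset has to be converted to a length. The sketch below reproduces only that arithmetic in plain Python so it runs without PyTorch. The helper names `local_logit_range` and `narrow_dim1`, the example lengths, and the assumption that `seq_start` advances by each padded length are all illustrative and not taken from `loss_functions.py`; `narrow_dim1` is a list-based stand-in for `Tensor.narrow`, not the real call.

```python
# Illustrative only: helper names and numbers are hypothetical, not from the commit.

def local_logit_range(seq_start, padded_seq_length, cp_size):
    """Return (start, length) of one packed sequence within a rank's local logits.

    Mirrors the arithmetic in the diff: global offsets are divided by the
    context-parallel world size, and the end offset is turned into a length.
    """
    logit_start = seq_start // cp_size
    logit_end = (seq_start + padded_seq_length) // cp_size
    return logit_start, logit_end - logit_start


def narrow_dim1(rows, start, length):
    """Plain-list stand-in for narrow(1, start, length) on a 2-D nested list."""
    return [row[start:start + length] for row in rows]


# Made-up example: three packed sequences, context-parallel size 2.
padded_seq_lengths = [8, 4, 12]
cp_size = 2

ranges = []
seq_start = 0  # assumed to advance by each padded length
for padded_len in padded_seq_lengths:
    ranges.append(local_logit_range(seq_start, padded_len, cp_size))
    seq_start += padded_len

# One batch row holding 12 local positions, labelled by their index.
local_logits = [list(range(12))]
second_seq = narrow_dim1(local_logits, *ranges[1])
```

With these numbers the second sequence occupies local positions 4 and 5, which is the same span `[:, 4:6]` would select; the difference in the commit is only in how the span is requested.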