fix: Fix DTensor slice crash after PyTorch 2.9 bump (1689)
r0.5.0
1 parent e883ac4 commit 0fef58c
nemo_rl/algorithms/loss_functions.py
@@ -922,11 +922,12 @@ def __call__(
                 if context_parallel_group is None
                 else torch.distributed.get_world_size(context_parallel_group)
             )
-            logit_slice_idxs = slice(
-                seq_start // cp_size,
-                (seq_start + padded_seq_lengths[seq_idx]) // cp_size,
-            )
-            next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :]
+            logit_start = seq_start // cp_size
+            logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
+            logit_length = logit_end - logit_start
+            next_token_logits_slice = next_token_logits.narrow(
+                1, logit_start, logit_length
+            )
 
             loss, metrics = self.loss_fn(
                 next_token_logits_slice,
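
Per the commit title, indexing the DTensor logits with a Python slice object started crashing after the PyTorch 2.9 bump, so the fix takes the same slice with Tensor.narrow instead. The sketch below uses a plain torch.Tensor (not a DTensor) with made-up shapes and an assumed cp_size to show that the two forms select identical elements; it does not reproduce the DTensor crash itself.

```python
import torch

# Illustrative sizes only; batch/seq/vocab and cp_size are not values from the repo.
batch, seq, vocab = 2, 16, 8
next_token_logits = torch.randn(batch, seq, vocab)

seq_start, seq_len, cp_size = 4, 8, 2
logit_start = seq_start // cp_size
logit_end = (seq_start + seq_len) // cp_size
logit_length = logit_end - logit_start

# Old form: a slice object applied via bracket indexing on dim 1.
sliced = next_token_logits[:, slice(logit_start, logit_end), :]

# New form used by the fix: narrow(dim, start, length).
narrowed = next_token_logits.narrow(1, logit_start, logit_length)

assert torch.equal(sliced, narrowed)  # same elements, same shape
```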