Skip to content

Commit 617973e

Browse files
authored
Update kd_loss.py
Gather DTensor logits and labels into full tensors in the KD loss so that knowledge distillation (KD) works with tensor parallelism (TP) > 1
1 parent 0d1eb3c commit 617973e

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

nemo_automodel/components/loss/kd_loss.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
import torch.nn as nn
1717
import torch.nn.functional as F
18+
from torch.distributed.tensor import DTensor
1819

1920

2021
class KDLoss(nn.Module):
@@ -68,6 +69,14 @@ def forward(
6869
teacher_logits = teacher_logits.view(-1, teacher_logits.shape[-1])
6970
if labels.ndim > 1:
7071
labels = labels.view(-1)
72+
73+
if isinstance(teacher_logits, DTensor):
74+
teacher_logits = teacher_logits.full_tensor()
75+
if isinstance(student_logits, DTensor):
76+
student_logits = student_logits.full_tensor()
77+
if isinstance(labels, DTensor):
78+
labels = labels.full_tensor()
79+
7180
t_logits = teacher_logits[valid_mask]
7281
s_logits = student_logits[valid_mask]
7382
labels = labels[valid_mask]

0 commit comments

Comments
 (0)