File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed
nemo_automodel/components/loss Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change 1515import torch
1616import torch .nn as nn
1717import torch .nn .functional as F
18+ from torch .distributed .tensor import DTensor
1819
1920
2021class KDLoss (nn .Module ):
@@ -68,6 +69,14 @@ def forward(
6869 teacher_logits = teacher_logits .view (- 1 , teacher_logits .shape [- 1 ])
6970 if labels .ndim > 1 :
7071 labels = labels .view (- 1 )
72+
73+ if isinstance (teacher_logits , DTensor ):
74+ teacher_logits = teacher_logits .full_tensor ()
75+ if isinstance (student_logits , DTensor ):
76+ student_logits = student_logits .full_tensor ()
77+ if isinstance (labels , DTensor ):
78+ labels = labels .full_tensor ()
79+
7180 t_logits = teacher_logits [valid_mask ]
7281 s_logits = student_logits [valid_mask ]
7382 labels = labels [valid_mask ]
You can’t perform that action at this time.
0 commit comments