
Commit 9aedfdf

Update distill Megatron plugin (#319)
Signed-off-by: Asha Anoosheh <[email protected]>
Parent: d406aa1

File tree

2 files changed (+280, -142 lines)

modelopt/torch/distill/distillation_model.py

Lines changed: 4 additions & 5 deletions
@@ -239,7 +239,7 @@ def compute_kd_loss(
         student_loss: torch.Tensor | None = None,
         loss_reduction_fn: Callable | None = None,
         skip_balancer: bool = False,
-        labels: torch.Tensor | None = None,
+        **loss_fn_kwargs,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Compute total loss for distillation backpropagation.
 
@@ -248,8 +248,8 @@ def compute_kd_loss(
             loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
                 loss-masking situations where the callable changes arguments each iteration.
             skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
-            labels: Labels to be passed to the loss function, if needed. This is necessary for losses that
-                require labels, such as MFTLoss.
+            **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
+                This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``.
 
         Returns:
             If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
@@ -268,8 +268,7 @@ def compute_kd_loss(
             student_layer._intermediate_output = None
             teacher_layer._intermediate_output = None
 
-            extra_kwargs = {"labels": labels} if labels is not None else {}
-            loss = loss_fn(out_s, out_t, **extra_kwargs)  # Student is pred, Teacher is target
+            loss = loss_fn(out_s, out_t, **loss_fn_kwargs)  # Student is pred, Teacher is target
             if loss_reduction_fn is not None:
                 # Needed in cases where a loss mask is used on non-scalar loss-fn outputs, prior to
                 # reducing to a scalar loss value.
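In short, the commit generalizes the hard-coded ``labels`` parameter of ``compute_kd_loss`` into an open-ended ``**loss_fn_kwargs`` pass-through, so any criterion-specific extras reach the loss function. Below is a minimal usage sketch, not part of the commit: it assumes the standard ``mtd.convert`` knowledge-distillation setup, and ``LabelAwareKDLoss`` is a hypothetical stand-in for a label-consuming criterion such as ``mtd.MFTLoss`` (whose constructor is not shown in this diff).

```python
# Hypothetical usage sketch of the new **loss_fn_kwargs pass-through.
# The toy modules, config values, and LabelAwareKDLoss are illustrative
# assumptions, not code from the repository.
import torch
import torch.nn as nn
import modelopt.torch.distill as mtd


class LabelAwareKDLoss(nn.Module):
    """Stand-in for a criterion that needs labels (e.g. ``mtd.MFTLoss``)."""

    def forward(self, out_s, out_t, labels=None):
        # Invoked by the DistillationModel as loss_fn(out_s, out_t, **loss_fn_kwargs).
        kd = nn.functional.mse_loss(out_s, out_t)  # pull student toward teacher
        if labels is not None:
            kd = kd + nn.functional.cross_entropy(out_s, labels)
        return kd


teacher = nn.Linear(16, 8)
student = nn.Linear(16, 8)

# Assumed config shape for modelopt's "kd_loss" mode; adjust to your setup.
model = mtd.convert(
    student,
    mode=[("kd_loss", {
        "teacher_model": teacher,
        "criterion": LabelAwareKDLoss(),
        "loss_balancer": None,
    })],
)

x = torch.randn(4, 16)
labels = torch.randint(0, 8, (4,))

out = model(x)  # DistillationModel forward caches student/teacher outputs
student_loss = nn.functional.cross_entropy(out, labels)

# Old API: labels went through a dedicated `labels=` parameter.
# New API: labels (or any extra kwarg) is forwarded via **loss_fn_kwargs.
loss = model.compute_kd_loss(student_loss=student_loss, labels=labels)
loss.backward()
```

Note that a call site passing ``labels=`` is unchanged; the gain is that arbitrary keyword arguments now flow through to the criterion without further edits to the ``compute_kd_loss`` signature.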
