modelopt/torch/distill/distillation_model.py (4 additions, 5 deletions)

@@ -239,7 +239,7 @@ def compute_kd_loss(
         student_loss: torch.Tensor | None = None,
         loss_reduction_fn: Callable | None = None,
         skip_balancer: bool = False,
-        labels: torch.Tensor | None = None,
+        **loss_fn_kwargs,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Compute total loss for distillation backpropagation.

@@ -248,8 +248,8 @@ def compute_kd_loss(
             loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
                 loss-masking situations where the callable changes arguments each iteration.
             skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
-            labels: Labels to be passed to the loss function, if needed. This is necessary for losses that
-                require labels, such as MFTLoss.
+            **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
+                This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``.

         Returns:
             If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.

@@ -268,8 +268,7 @@ def compute_kd_loss(
             student_layer._intermediate_output = None
             teacher_layer._intermediate_output = None

-            extra_kwargs = {"labels": labels} if labels is not None else {}
-            loss = loss_fn(out_s, out_t, **extra_kwargs)  # Student is pred, Teacher is target
+            loss = loss_fn(out_s, out_t, **loss_fn_kwargs)  # Student is pred, Teacher is target
             if loss_reduction_fn is not None:
                 # Needed in cases where a loss mask is used on non-scalar loss-fn outputs, prior to
                 # reducing to a scalar loss value.
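For context, here is a minimal usage sketch of what this change enables. The `mtd.convert` call, the `kd_config` keys, and the `student`, `teacher`, and `dataloader` objects are illustrative assumptions based on the usual modelopt distillation workflow and are not part of this PR; the `mtd.MFTLoss()` constructor arguments are also omitted and may differ. Only the `compute_kd_loss(..., labels=...)` passthrough reflects the new behavior.

```python
import modelopt.torch.distill as mtd

# Assumed setup (not from this PR): `student`, `teacher`, and `dataloader` already exist.
kd_config = {
    "teacher_model": teacher,
    "criterion": mtd.MFTLoss(),  # a criterion that needs labels at loss-computation time
}
model = mtd.convert(student, mode=[("kd_loss", kd_config)])

for batch in dataloader:
    model(batch["input_ids"])  # forward pass records student/teacher intermediate outputs
    # New behavior: extra keyword arguments (here `labels`) are forwarded verbatim to the
    # criterion via **loss_fn_kwargs, replacing the previous hard-coded `labels` parameter.
    loss = model.compute_kd_loss(labels=batch["labels"])
    loss.backward()
```

The advantage over the replaced `labels`-only parameter is that criteria with other requirements (attention masks, sample weights, and so on) can be supported without further signature changes, since every keyword is forwarded as-is to each loss function.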