@@ -239,7 +239,7 @@ def compute_kd_loss(
         student_loss: torch.Tensor | None = None,
         loss_reduction_fn: Callable | None = None,
         skip_balancer: bool = False,
-        labels: torch.Tensor | None = None,
+        **loss_fn_kwargs,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Compute total loss for distillation backpropagation.
 
@@ -248,8 +248,8 @@ def compute_kd_loss(
             loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
                 loss-masking situations where the callable changes arguments each iteration.
             skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
-            labels: Labels to be passed to the loss function, if needed. This is necessary for losses that
-                require labels, such as MFTLoss.
+            **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
+                This is useful for losses that require extra inputs, such as labels for ``mtd.MFTLoss``.
 
         Returns:
             If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
@@ -268,8 +268,7 @@ def compute_kd_loss(
             student_layer._intermediate_output = None
             teacher_layer._intermediate_output = None
 
-            extra_kwargs = {"labels": labels} if labels is not None else {}
-            loss = loss_fn(out_s, out_t, **extra_kwargs)  # Student is pred, Teacher is target
+            loss = loss_fn(out_s, out_t, **loss_fn_kwargs)  # Student is pred, Teacher is target
             if loss_reduction_fn is not None:
                 # Needed in cases where a loss mask is used on non-scalar loss-fn outputs, prior to
                 # reducing to a scalar loss value.
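
For context, a minimal usage sketch of the new signature, assuming a model already converted for distillation with a label-dependent intermediate loss such as ``MFTLoss``; the helper name ``training_step`` and the ``distill_model``/``inputs``/``labels`` arguments are illustrative placeholders, not taken from this patch:

```python
import torch.nn.functional as F


def training_step(distill_model, inputs, labels):
    """Illustrative only: forward extra inputs to a label-dependent distillation loss."""
    outputs = distill_model(inputs)  # forward pass records the intermediate activations
    student_loss = F.cross_entropy(outputs, labels)
    # Any keyword argument given here is now forwarded verbatim to each registered
    # loss function via **loss_fn_kwargs, replacing the old hard-coded ``labels`` parameter.
    return distill_model.compute_kd_loss(student_loss=student_loss, labels=labels)
```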