@@ -239,7 +239,7 @@ def compute_kd_loss(
         student_loss: torch.Tensor | None = None,
         loss_reduction_fn: Callable | None = None,
         skip_balancer: bool = False,
-        labels: torch.Tensor | None = None,
+        **loss_fn_kwargs,
     ) -> torch.Tensor | dict[str, torch.Tensor]:
         """Compute total loss for distillation backpropagation.

@@ -248,8 +248,8 @@ def compute_kd_loss(
             loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
                 loss-masking situations where the callable changes arguments each iteration.
             skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
-            labels: Labels to be passed to the loss function, if needed. This is necessary for losses that
-                require labels, such as MFTLoss.
+            **loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
+                This facilitates losses that require extra arguments, such as labels for ``mtd.MFTLoss``.

         Returns:
             If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
@@ -268,8 +268,7 @@ def compute_kd_loss(
             student_layer._intermediate_output = None
             teacher_layer._intermediate_output = None

-            extra_kwargs = {"labels": labels} if labels is not None else {}
-            loss = loss_fn(out_s, out_t, **extra_kwargs)  # Student is pred, Teacher is target
+            loss = loss_fn(out_s, out_t, **loss_fn_kwargs)  # Student is pred, Teacher is target
             if loss_reduction_fn is not None:
                 # Needed in cases where a loss mask is used on non-scalar loss-fn outputs, prior to
                 # reducing to a scalar loss value.
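
For context on the widened interface, here is a minimal caller-side sketch of a loss that consumes a forwarded keyword. The `MaskedCosineLoss` class, the `distill_model` variable, and the `-100` ignore-index convention are illustrative assumptions; only `compute_kd_loss`, `**loss_fn_kwargs`, and `mtd.MFTLoss` come from this change.

```python
import torch

# Hypothetical distillation loss that needs `labels` to mask out padding.
# Any keyword passed to compute_kd_loss() is now forwarded to forward().
class MaskedCosineLoss(torch.nn.Module):
    def forward(self, out_student, out_teacher, labels=None, **kwargs):
        # Per-position dissimilarity between student and teacher outputs.
        loss = 1 - torch.nn.functional.cosine_similarity(out_student, out_teacher, dim=-1)
        if labels is not None:
            loss = loss * (labels != -100)  # drop contributions from ignored positions
        return loss.mean()

# With a DistillationModel configured to use such a loss, the call becomes:
#     loss = distill_model.compute_kd_loss(labels=batch["labels"])
```

Note that `loss_fn(out_s, out_t, **loss_fn_kwargs)` forwards the same kwargs to every registered loss function, so each one must accept (or swallow via `**kwargs`) whatever extras the caller passes.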