@@ -265,10 +265,10 @@ def prediction_step(
265
265
labels_list = pad_sequence (labels_list , batch_first = True , padding_value = 0 )
266
266
return None , response_list , labels_list
267
267
268
- def compute_loss (self , model , inputs , return_outputs = False , num_items_in_batch = None ):
268
+ def _prepare_inputs (self , inputs ):
269
+ inputs = super ()._prepare_inputs (inputs )
269
270
from swift .plugin .loss import get_loss_func
270
271
loss_kwargs = {}
271
- labels = None
272
272
compute_loss_func = self .compute_loss_func
273
273
loss_scale = inputs .pop ('loss_scale' , None )
274
274
if loss_scale is not None :
@@ -287,14 +287,25 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
287
287
if inputs .get ('position_ids' ) is not None :
288
288
loss_kwargs ['position_ids' ] = inputs ['position_ids' ]
289
289
290
- if (self .label_smoother is not None or compute_loss_func is not None ) and 'labels' in inputs :
291
- labels = inputs .pop ('labels' )
292
-
293
- use_logits_to_keep = self .get_use_logits_to_keep ('labels' in inputs )
290
+ use_logits_to_keep = self .get_use_logits_to_keep ('labels' in inputs and self .label_smoother is None
291
+ and compute_loss_func is None )
294
292
if use_logits_to_keep :
295
293
inputs ['labels' ], logits_to_keep = self .get_logits_to_keep (inputs ['labels' ])
296
294
if logits_to_keep is not None :
297
295
inputs ['logits_to_keep' ] = logits_to_keep
296
+
297
+ inputs ['compute_loss_func' ] = compute_loss_func
298
+ inputs ['loss_kwargs' ] = loss_kwargs
299
+ return inputs
300
+
301
+ def compute_loss (self , model , inputs , return_outputs = False , num_items_in_batch = None ):
302
+ labels = None
303
+ compute_loss_func = inputs .pop ('compute_loss_func' , None )
304
+ loss_kwargs = inputs .pop ('loss_kwargs' , {})
305
+
306
+ if (self .label_smoother is not None or compute_loss_func is not None ) and 'labels' in inputs :
307
+ labels = inputs .pop ('labels' )
308
+
298
309
outputs = model (** inputs )
299
310
# Save past state if it exists
300
311
# TODO: this needs to be fixed and made cleaner later.
0 commit comments