@@ -237,8 +237,9 @@ def prediction_step(
237237 ** gen_kwargs ,
238238 ) -> Tuple [Optional [float ], Optional [torch .Tensor ], Optional [torch .Tensor ]]:
239239 if not self .args .predict_with_generate or prediction_loss_only :
240- return super ().prediction_step (
241- model , inputs , prediction_loss_only = prediction_loss_only , ignore_keys = ignore_keys )
240+ with self .template .forward_context (self .model , inputs ):
241+ return super ().prediction_step (
242+ model , inputs , prediction_loss_only = prediction_loss_only , ignore_keys = ignore_keys )
242243 from swift .llm import RequestConfig , InferRequest
243244 data_list = inputs ['_data' ]
244245 labels_list = [InferRequest .remove_response (data ['messages' ]) for data in data_list ]
@@ -340,5 +341,5 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
340341 return (loss , outputs ) if return_outputs else loss
341342
342343 def training_step (self , model , inputs , * args , ** kwargs ):
343- with self .template .training_step_context (self .model , inputs ):
344+ with self .template .forward_context (self .model , inputs ):
344345 return super ().training_step (model , inputs , * args , ** kwargs )
0 commit comments