diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 6c1078b14a..76eec94014 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1254,7 +1254,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: torch.stack([data['advantages'] for data in batch]) }) - with torch.no_grad(): + with torch.inference_mode(): batch_encoded_inputs['old_per_token_logps'] = ( self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] if self.old_policy() else None)