From 17f918ffbe76d44e212e4d5b52889d251b615556 Mon Sep 17 00:00:00 2001 From: chenjianhuii <875121393@qq.com> Date: Wed, 3 Sep 2025 15:31:14 +0800 Subject: [PATCH 1/2] bug fix: RuntimeError when training GRPO with LoRA and PtEngine --- swift/llm/infer/infer_engine/pt_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 6585edd8d0..ffab0eae1a 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -461,7 +461,7 @@ async def _gen_wrapper(): return await queue.get() # Ensure `template._post_encode` has no gradient. - @torch.inference_mode() + @torch.no_grad() def _infer( self, infer_requests: List[InferRequest], From 723af9c260693763a1ec05d10ad0a7751620eee7 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 4 Sep 2025 17:03:33 +0800 Subject: [PATCH 2/2] fix pt zero3 --- swift/llm/infer/infer_engine/pt_engine.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index ffab0eae1a..6585edd8d0 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -461,7 +461,7 @@ async def _gen_wrapper(): return await queue.get() # Ensure `template._post_encode` has no gradient. - @torch.no_grad() + @torch.inference_mode() def _infer( self, infer_requests: List[InferRequest], diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 6c1078b14a..76eec94014 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1254,7 +1254,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: torch.stack([data['advantages'] for data in batch]) }) - with torch.no_grad(): + with torch.inference_mode(): batch_encoded_inputs['old_per_token_logps'] = ( self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] if self.old_policy() else None)