Skip to content

Commit 730ecc9

Browse files
authored
[train] Fix qwen2.5-vl use_cache (#4458)
1 parent 3478bdb commit 730ecc9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

swift/trainers/mixin.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -316,8 +316,8 @@ def clip_grad_norm_(self, parameters, *args, **kwargs):
     def _prepare_gradient_checkpointing(self, model) -> None:
         from swift.llm import HfConfigFactory, get_model_arch, deep_getattr, dynamic_gradient_checkpointing
         args = self.args
+        HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
         if args.gradient_checkpointing or args.vit_gradient_checkpointing:
-            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
             dynamic_gradient_checkpointing(model, args.vit_gradient_checkpointing)
         if args.gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

0 commit comments

Comments
 (0)