Skip to content

Commit f1e8821

Browse files
committed
[model] fix ovis gradient_checkpointing vit no_grad (#4571)
1 parent f239fb8 commit f1e8821

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

swift/llm/model/model/qwen.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,12 @@ def get_model_tokenizer_ovis(*args, **kwargs):
808808
use_submodel_func(model, 'llm', func_list)
809809
embedding = model.get_input_embeddings()
810810
patch_output_clone(embedding)
811+
if hasattr(model.visual_tokenizer, 'backbone'):
812+
backbone = model.visual_tokenizer.backbone
813+
if hasattr(backbone, 'vision_model'):
814+
patch_get_input_embeddings(model.visual_tokenizer, 'backbone.vision_model.embeddings')
815+
elif hasattr(backbone, 'preprocessor'):
816+
patch_get_input_embeddings(model.visual_tokenizer, 'backbone.preprocessor.patchifier')
811817
try:
812818
# fix device_map
813819
from transformers.cache_utils import HybridCache

0 commit comments

Comments
 (0)