
Commit 153d226

Fix missing property access for multimodal models (#966)
## Summary

This PR fixes access to missing attributes for multimodal models in `src/liger_kernel/transformers/monkey_patch.py`. The main change is to consistently access submodules (such as `language_model`, `vision_tower`, and `visual`) through the parent model's `.model` attribute rather than directly on the parent model itself. This fixes an `AttributeError` that appears after the following PR was merged in transformers:

- huggingface/transformers#42156

See the associated issue in TRL:

- huggingface/trl#4601

Fix #960.

## Details

Fix: consistent attribute access via `.model`

* Updated all references to submodules such as `language_model`, `vision_tower`, and `visual` to go through the `.model` attribute (e.g., `model.model.language_model` instead of `model.language_model`) across all kernel application functions for models including LLava, Mllama, Gemma3, PaliGemma, Qwen2 VL, Qwen2.5 VL, Qwen3 VL, Qwen3 VL MoE, GLM4V, GLM4V MoE, and InternVL.

Normalization and patching logic updates

* Adjusted normalization and patching calls to operate on submodels accessed via `.model`, ensuring that layer normalization and RMS normalization are consistently applied to the correct components.

These changes make the codebase more maintainable and robust against future changes in model class implementations.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
1 parent 67d98e4 commit 153d226
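To make the failure mode concrete, here is a minimal, self-contained sketch of the wrapper nesting this commit accounts for. The class names are placeholders, not the real transformers classes; the removal of the back-compat properties is per huggingface/transformers#42156:

```python
# Placeholder classes; only the nesting mirrors the real transformers layout.
class TextModel: ...
class VisionTower: ...


class MultiModalModel:  # plays the role of e.g. LlavaModel
    def __init__(self):
        self.language_model = TextModel()
        self.vision_tower = VisionTower()


class ForConditionalGeneration:  # plays the role of e.g. LlavaForConditionalGeneration
    def __init__(self):
        # Submodules live on the inner `.model`; after transformers#42156 the
        # convenience properties forwarding them from this level are gone.
        self.model = MultiModalModel()


m = ForConditionalGeneration()
assert isinstance(m.model.language_model, TextModel)  # the access path this PR adopts
# m.language_model  # would raise AttributeError on current transformers
```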

File tree

2 files changed: +103 −98 lines changed


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 38 additions & 33 deletions
```diff
@@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
```
```diff
@@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
```
```diff
@@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
             if isinstance(language_model, MllamaForCausalLM):
                 text_model: MllamaTextModel = language_model.model
             else:
```
```diff
@@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
         # instance variables that reference already-instantiated modules

         if isinstance(model, Gemma3ForConditionalGeneration):
-            if isinstance(model.vision_tower, SiglipVisionModel):
-                vision_tower = model.vision_tower
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower

                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

```
```diff
@@ -1132,15 +1132,15 @@ def apply_liger_kernel_to_gemma3(
                 raise TypeError("The vision tower must be SiglipVisionModel")

             if rms_norm:
-                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

             apply_liger_kernel_to_gemma3_text(
                 rope=rope,
                 cross_entropy=False,
                 fused_linear_cross_entropy=False,
                 rms_norm=rms_norm,
                 geglu=geglu,
-                model=model.language_model,
+                model=model.model.language_model,
             )

         else:
```
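The `_patch_layer_norm_module` / `_patch_rms_norm_module_for_gemma3` calls above mutate modules on an already-instantiated model, so nothing needs to be re-created. Their bodies are not part of this diff; the following is only a generic sketch of that in-place patching technique (a class swap on a live module), assuming torch >= 2.4 for `nn.RMSNorm`, and is not Liger's actual helper:

```python
import torch
import torch.nn as nn

class FasterRMSNorm(nn.Module):
    """Stand-in for an optimized kernel-backed norm; not Liger's implementation."""
    def forward(self, x):
        # Reuses the patched instance's existing `weight` and `eps` attributes.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

def patch_rms_norm_module(module):
    # Rebind the instance's class so its forward dispatches to the optimized
    # version while leaving the trained parameters untouched.
    module.__class__ = FasterRMSNorm

norm = nn.RMSNorm(8, eps=1e-6)  # an already-instantiated module on a live model
patch_rms_norm_module(norm)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```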
```diff
@@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
         if not isinstance(model, PaliGemmaForConditionalGeneration):
             raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

-        vision_tower: SiglipVisionModel = model.vision_tower
+        vision_tower: SiglipVisionModel = model.model.vision_tower

         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

```
```diff
@@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
             _patch_layer_norm_module(layer.layer_norm1)
             _patch_layer_norm_module(layer.layer_norm2)

-        language_model = model.language_model
+        language_model = model.model.language_model

         if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
```
```diff
@@ -1593,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
             text_model: Qwen2VLTextModel = model.language_model
             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2VLTextModel):
```
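This hunk shows the recurring shape of the fix: the old code leaned on back-compat properties shared by the wrapper classes, while the new code matches each entry-point class explicitly and resolves the submodules once. A generic sketch of that dispatch, using placeholder classes rather than the real Qwen2 VL types:

```python
class TextModel: ...                     # e.g. Qwen2VLTextModel

class InnerModel:                        # e.g. Qwen2VLModel
    def __init__(self):
        self.language_model = TextModel()

class OuterModel:                        # e.g. Qwen2VLForConditionalGeneration
    def __init__(self):
        self.model = InnerModel()

def resolve_text_model(model):
    if isinstance(model, OuterModel):
        return model.model.language_model  # reach through the inner wrapper
    elif isinstance(model, InnerModel):
        return model.language_model
    elif isinstance(model, TextModel):
        return model
    raise TypeError(f"Unsupported model type: {type(model)}")

print(type(resolve_text_model(OuterModel())).__name__)  # TextModel
```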
```diff
@@ -1684,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
             text_model: Qwen2_5_VLTextModel = model.language_model
             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2_5_VLTextModel):
```
```diff
@@ -1702,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

         if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
+            for vision_block in vision_model.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
```
```diff
@@ -1771,7 +1769,9 @@ def apply_liger_kernel_to_qwen3_vl(
         modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
             text_model: Qwen3VLTextModel = model.language_model
         elif isinstance(model, Qwen3VLTextModel):
             text_model = model
```
```diff
@@ -1846,7 +1846,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
         modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
             text_model: Qwen3VLMoeTextModel = model.language_model
         elif isinstance(model, Qwen3VLMoeTextModel):
             text_model = model
```
```diff
@@ -2191,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
             text_model: Glm4vTextModel = model.language_model
             vision_model: Glm4vVisionModel = model.visual
         elif isinstance(model, Glm4vTextModel):
```
```diff
@@ -2281,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
             text_model: Glm4vMoeTextModel = model.language_model
             vision_model: Glm4vMoeVisionModel = model.visual
             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
```
```diff
@@ -2387,8 +2390,10 @@ def apply_liger_kernel_to_internvl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
-            # NOTE: language_model and visual properties can be accessed throught conditional class.
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
             text_model = model.language_model
             vision_model: InternVLVisionModel = model.vision_tower
         else:
```
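End to end, a caller exercises these code paths simply by applying a kernel to an instantiated model. A hedged usage sketch, assuming a recent transformers with the restructured multimodal classes; the checkpoint path is a placeholder:

```python
import torch
from transformers import LlavaForConditionalGeneration
from liger_kernel.transformers import apply_liger_kernel_to_llava

MODEL_PATH = "path/to/llava-checkpoint"  # placeholder, supply your own

model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16)

# With this commit, patching reaches the submodules via model.model.language_model
# and model.model.vision_tower instead of the removed top-level properties.
apply_liger_kernel_to_llava(model=model)
```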
