Skip to content

Commit 08f2ea4

Browse files
Manan17lancertsvaibhavjindal
authored
Fix mllama monkey patch tests (#737)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> With transformers >= 4.52.0, there was some refactoring of mllama model code. Fixing part of #729 <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> Gemma3 monkey patch tests pass ``` pytest test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation pytest test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm ``` <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <[email protected]> Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent cd6ec32 commit 08f2ea4

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,13 +428,14 @@ def apply_liger_kernel_to_mllama(
428428
if isinstance(model, MllamaForConditionalGeneration):
429429
language_model: MllamaForCausalLM = model.language_model
430430
vision_model: MllamaVisionModel = model.vision_model
431-
text_model: MllamaTextModel = language_model.model
431+
text_model: MllamaTextModel = language_model
432432
elif isinstance(model, MllamaForCausalLM):
433433
text_model = model.model
434434
vision_model = None
435435
elif isinstance(model, MllamaTextModel):
436436
text_model = model
437437
vision_model = None
438+
438439
else:
439440
raise ValueError(f"Unsupported Mllama model type: {type(model)}")
440441

test/transformers/test_monkey_patch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,10 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
348348
assert isinstance(dummy_model_instance, MllamaForConditionalGeneration)
349349

350350
# Check that model instance variables are not yet patched with Liger modules
351-
assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) != inspect.getsource(
351+
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(
352352
LigerRMSNorm.forward
353353
)
354-
for layer in dummy_model_instance.language_model.model.layers:
354+
for layer in dummy_model_instance.language_model.layers:
355355
assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
356356
assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
357357
assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
@@ -377,10 +377,10 @@ def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
377377
_apply_liger_kernel_to_instance(model=dummy_model_instance)
378378

379379
# Check that the model's instance variables were correctly patched with Liger modules
380-
assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) == inspect.getsource(
380+
assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource(
381381
LigerRMSNorm.forward
382382
)
383-
for layer in dummy_model_instance.language_model.model.layers:
383+
for layer in dummy_model_instance.language_model.layers:
384384
assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
385385
assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
386386
assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)

0 commit comments

Comments
 (0)