
Commit e6fb45a

Fix gemma3 monkey patch tests (#735)
## Summary

With transformers >= 4.52.0, there was some refactoring of the gemma3 model code. Patching should still work for previous transformers versions, but the gemma3 tests won't pass against those older versions; I'm not sure we want to maintain that logic in the tests.

For reference:

- Before (transformers <= 4.51.3):

  ```
  Gemma3ForConditionalGeneration
  - language_model (Gemma3ForCausalLM)
    - model (Gemma3TextModel)
      - layers/norm/etc.
  ```

- After:

  ```
  Gemma3ForConditionalGeneration
  - model (Gemma3Model)
    - language_model (Gemma3TextModel)
      - layers/norm/etc.
  - language_model (for backwards-compatibility, points to model.language_model (Gemma3TextModel))
  ```

Fixing part of #729

## Testing Done

Gemma3 monkey patch tests pass:

```
pytest test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation
pytest test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_gemma3_text
```

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent ea3ac1b commit e6fb45a
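To make the layout change concrete, here is a minimal sketch (not part of this commit) of how the `Gemma3TextModel` that owns `layers` and `norm` can be resolved under either layout; the `get_text_model` helper name is hypothetical:

```python
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3ForCausalLM,
    Gemma3ForConditionalGeneration,
)


def get_text_model(model: Gemma3ForConditionalGeneration):
    """Resolve the Gemma3TextModel regardless of transformers version (sketch)."""
    lm = model.language_model
    if isinstance(lm, Gemma3ForCausalLM):
        # transformers <= 4.51.3: language_model is a Gemma3ForCausalLM wrapper,
        # and its .model attribute is the Gemma3TextModel.
        return lm.model
    # transformers >= 4.52.0: language_model is the backwards-compatibility alias
    # for model.language_model, i.e. already the Gemma3TextModel.
    return lm
```

The new-layout branch relies on the backwards-compatibility alias described above; `model.model.language_model` would resolve the same module there.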

File tree

2 files changed (+9, -10 lines)


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -776,7 +776,7 @@ def apply_liger_kernel_to_gemma3_text(
 
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
-    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM, Gemma3TextModel
 
     from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
@@ -807,9 +807,9 @@ def apply_liger_kernel_to_gemma3_text(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -1625,7 +1625,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
```
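With this change, `apply_liger_kernel_to_gemma3_text` can patch an already-instantiated `Gemma3TextModel` directly, not only a `Gemma3ForCausalLM` wrapper. A hedged usage sketch; the tiny config values below are placeholders chosen for illustration, not from this commit:

```python
from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text

# Placeholder config, just large enough to build an instance to patch.
text_config = Gemma3TextConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=2,
    num_key_value_heads=1,
)

causal_lm = Gemma3ForCausalLM._from_config(text_config)

# Previously only the Gemma3ForCausalLM wrapper was accepted here; after this
# commit the bare Gemma3TextModel (causal_lm.model) is patched directly too.
apply_liger_kernel_to_gemma3_text(model=causal_lm.model)
```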

test/transformers/test_monkey_patch.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -667,7 +667,7 @@ def test_apply_liger_kernel_to_instance_for_gemma3_text():
 
 
 @pytest.mark.skipif(not is_gemma3_available(), reason="gemma3 module not available")
-def test_apply_liger_kernel_to_instance_for_gemma3():
+def test_apply_liger_kernel_to_instance_for_gemma3_conditional_generation():
     # Ensure any monkey patching is cleaned up for subsequent tests
 
     with patch("transformers.models.gemma3.modeling_gemma3"):
@@ -687,8 +687,8 @@ def test_apply_liger_kernel_to_instance_for_gemma3():
             intermediate_size=64,
         )
         config = transformers.models.gemma3.configuration_gemma3.Gemma3Config(text_config, vision_config)
-        dummy_model_instance = Gemma3ForConditionalGeneration._from_config(config)
 
+        dummy_model_instance = Gemma3ForConditionalGeneration._from_config(config)
         assert isinstance(dummy_model_instance, Gemma3ForConditionalGeneration)
 
         # Check that model instance variables are not yet patched with Liger modules
@@ -704,11 +704,11 @@ def test_apply_liger_kernel_to_instance_for_gemma3():
             dummy_model_instance.multi_modal_projector.mm_soft_emb_norm.forward
         ) != inspect.getsource(LigerRMSNorm.forward)
 
-        assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) != inspect.getsource(
+        assert inspect.getsource(dummy_model_instance.language_model.norm.forward) != inspect.getsource(
             LigerRMSNorm.forward
         )
 
-        for layer in dummy_model_instance.language_model.model.layers:
+        for layer in dummy_model_instance.language_model.layers:
             assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerGEGLUMLP.forward)
             assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
@@ -736,10 +736,10 @@ def test_apply_liger_kernel_to_instance_for_gemma3():
             dummy_model_instance.multi_modal_projector.mm_soft_emb_norm.forward
         ) == inspect.getsource(LigerRMSNorm.forward)
 
-        assert inspect.getsource(dummy_model_instance.language_model.model.norm.forward) == inspect.getsource(
+        assert inspect.getsource(dummy_model_instance.language_model.norm.forward) == inspect.getsource(
             LigerRMSNorm.forward
         )
-        for layer in dummy_model_instance.language_model.model.layers:
+        for layer in dummy_model_instance.language_model.layers:
             assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerGEGLUMLP.forward)
             assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
             assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
```
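The assertions above detect instance patching by comparing the source of each module's bound `forward` with the corresponding Liger implementation. A standalone sketch of that check; the `is_patched_with` helper is mine, not from the test file:

```python
import inspect

from liger_kernel.transformers.rms_norm import LigerRMSNorm


def is_patched_with(module, liger_cls) -> bool:
    """Return True if module.forward's source matches liger_cls.forward's source."""
    return inspect.getsource(module.forward) == inspect.getsource(liger_cls.forward)


# Usage, assuming `text_model` is a Gemma3TextModel as in the tests above:
# assert not is_patched_with(text_model.norm, LigerRMSNorm)   # before patching
# ... apply the Liger kernel to the model instance ...
# assert is_patched_with(text_model.norm, LigerRMSNorm)       # after patching
```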
