Commit e808281

vvvdwbvvv, lancerts, and Tcc0403 authored
Fix missing test_apply_liger_kernel_to_instance_for_paligemma test for PaliGemma in test_monkey_patch.py (#785)
## Summary

What this PR adds:

- A new unit test, `test_apply_liger_kernel_to_instance_for_paligemma()`, in `test/transformers/test_monkey_patch.py`. Fixes #776.
- The test instantiates a dummy `PaliGemmaForConditionalGeneration` model, confirms it is un-patched, runs `_apply_liger_kernel_to_instance()`, then verifies that:
  - `model.forward` is replaced by `paligemma_lce_forward()`.
  - `vision_tower.vision_model.post_layernorm.forward` is replaced by `LigerLayerNorm.forward`.
  - Every encoder layer's `layer_norm1.forward` and `layer_norm2.forward` are also replaced.
  - Source equality is checked with `inspect.getsource` before and after patching.

## Testing Done

`transformers==4.49.0`

<details>
<summary>Test result</summary>

❯ python3 -m pytest test/transformers/test_monkey_patch.py -k paligemma -v -rP
============================================== test session starts ==============================================
platform linux -- Python 3.11.11, pytest-8.4.1, pluggy-1.6.0 -- /home/vvvdwbvvv/.local/bin/python3
cachedir: .pytest_cache
rootdir: /home/vvvdwbvvv/develop/Liger-Kernel
configfile: pyproject.toml
plugins: asyncio-1.0.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 33 items / 32 deselected / 1 selected

test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_paligemma
------------------------------------------------- live log call -------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
PASSED [100%]
==================================================== PASSES =====================================================
_______________________________ test_apply_liger_kernel_to_instance_for_paligemma _______________________________
--------------------------------------------- Captured stdout call ----------------------------------------------
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 48, kernel_size=(16, 16), stride=(16, 16), padding=valid)
        (position_embedding): Embedding(196, 48)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-1): 2 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=48, out_features=48, bias=True)
              (v_proj): Linear(in_features=48, out_features=48, bias=True)
              (q_proj): Linear(in_features=48, out_features=48, bias=True)
              (out_proj): Linear(in_features=48, out_features=48, bias=True)
            )
            (layer_norm1): LigerLayerNorm((48,), eps=1e-05)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=48, out_features=64, bias=True)
              (fc2): Linear(in_features=64, out_features=48, bias=True)
            )
            (layer_norm2): LigerLayerNorm((48,), eps=1e-05)
          )
        )
      )
      (post_layernorm): LigerLayerNorm((48,), eps=1e-05)
      (head): SiglipMultiheadAttentionPoolingHead(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=48, out_features=48, bias=True)
        )
        (layernorm): LigerLayerNorm((48,), eps=1e-05, elementwise_affine=True)
        (mlp): SiglipMLP(
          (activation_fn): PytorchGELUTanh()
          (fc1): Linear(in_features=48, out_features=64, bias=True)
          (fc2): Linear(in_features=64, out_features=48, bias=True)
        )
      )
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=48, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 32, padding_idx=0)
      (layers): ModuleList(
        (0-1): 2 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=32, out_features=4096, bias=False)
            (k_proj): Linear(in_features=32, out_features=4096, bias=False)
            (v_proj): Linear(in_features=32, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=32, bias=False)
          )
          (mlp): LigerGEGLUMLP(
            (gate_proj): Linear(in_features=32, out_features=64, bias=False)
            (up_proj): Linear(in_features=32, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=32, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
          (post_attention_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
        )
      )
      (norm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
      (rotary_emb): GemmaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=32, out_features=256000, bias=False)
  )
)
----------------------------------------------- Captured log call -----------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
======================================= 1 passed, 32 deselected in 1.78s ========================================

</details>

`transformers==4.53.1`

<details>
<summary>Test result</summary>

❯ python3 -m pytest test/transformers/test_monkey_patch.py -k paligemma -v -rP
============================================== test session starts ==============================================
platform linux -- Python 3.11.11, pytest-8.4.1, pluggy-1.6.0 -- /home/vvvdwbvvv/.local/bin/python3
cachedir: .pytest_cache
rootdir: /home/vvvdwbvvv/develop/Liger-Kernel
configfile: pyproject.toml
plugins: asyncio-1.0.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 33 items / 32 deselected / 1 selected

test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_paligemma
------------------------------------------------- live log call -------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
PASSED [100%]
==================================================== PASSES =====================================================
_______________________________ test_apply_liger_kernel_to_instance_for_paligemma _______________________________
--------------------------------------------- Captured stdout call ----------------------------------------------
PaliGemmaForConditionalGeneration(
  (model): PaliGemmaModel(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 48, kernel_size=(16, 16), stride=(16, 16), padding=valid)
          (position_embedding): Embedding(196, 48)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-1): 2 x SiglipEncoderLayer(
              (layer_norm1): LigerLayerNorm((48,), eps=1e-05)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=48, out_features=48, bias=True)
                (v_proj): Linear(in_features=48, out_features=48, bias=True)
                (q_proj): Linear(in_features=48, out_features=48, bias=True)
                (out_proj): Linear(in_features=48, out_features=48, bias=True)
              )
              (layer_norm2): LigerLayerNorm((48,), eps=1e-05)
              (mlp): SiglipMLP(
                (activation_fn): PytorchGELUTanh()
                (fc1): Linear(in_features=48, out_features=64, bias=True)
                (fc2): Linear(in_features=64, out_features=48, bias=True)
              )
            )
          )
        )
        (post_layernorm): LigerLayerNorm((48,), eps=1e-05)
        (head): SiglipMultiheadAttentionPoolingHead(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=48, out_features=48, bias=True)
          )
          (layernorm): LigerLayerNorm((48,), eps=1e-05, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=48, out_features=64, bias=True)
            (fc2): Linear(in_features=64, out_features=48, bias=True)
          )
        )
      )
    )
    (multi_modal_projector): PaliGemmaMultiModalProjector(
      (linear): Linear(in_features=48, out_features=2048, bias=True)
    )
    (language_model): GemmaModel(
      (embed_tokens): Embedding(256000, 32, padding_idx=0)
      (layers): ModuleList(
        (0-1): 2 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=32, out_features=4096, bias=False)
            (k_proj): Linear(in_features=32, out_features=4096, bias=False)
            (v_proj): Linear(in_features=32, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=32, bias=False)
          )
          (mlp): LigerGEGLUMLP(
            (gate_proj): Linear(in_features=32, out_features=64, bias=False)
            (up_proj): Linear(in_features=32, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=32, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
          (post_attention_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
        )
      )
      (norm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
      (rotary_emb): GemmaRotaryEmbedding()
    )
  )
  (lm_head): Linear(in_features=32, out_features=256000, bias=False)
)
----------------------------------------------- Captured log call -----------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
======================================= 1 passed, 32 deselected in 2.42s ========================================

</details>

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Tcc0403 <[email protected]>
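For context, here is a minimal sketch (not part of this commit) of the instance-level patching flow the new test exercises. The config values mirror the dummy model built in the test; the `_apply_liger_kernel_to_instance` import path is an assumption based on the module named in the log output above.

```python
# Hedged sketch: build a tiny PaliGemma model and patch it in place with Liger kernels.
from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration

# Assumed import path; the logs above show the function lives in
# liger_kernel.transformers.monkey_patch.
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

config = PaliGemmaConfig(
    # Tiny sizes copied from the dummy config used in the new test.
    text_config={"num_hidden_layers": 2, "rms_norm_eps": 1e-5, "hidden_size": 32,
                 "intermediate_size": 64, "hidden_act": "silu"},
    vision_config={"num_hidden_layers": 2, "layer_norm_eps": 1e-5, "hidden_size": 48,
                   "intermediate_size": 64},
)
model = PaliGemmaForConditionalGeneration(config)

# Swaps forward methods and norm modules in place, dispatching on the model type
# ("paligemma" in the log output above).
_apply_liger_kernel_to_instance(model=model)

print(type(model.vision_tower.vision_model.post_layernorm))  # expected: LigerLayerNorm
```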
1 parent 5700de2 commit e808281

File tree: 2 files changed (+71, -2 lines)


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 4 additions & 2 deletions
@@ -1096,7 +1096,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -1155,7 +1157,7 @@ def apply_liger_kernel_to_paligemma(
 
     language_model = model.language_model
 
-    if isinstance(language_model, GemmaForCausalLM):
+    if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
         apply_liger_kernel_to_gemma(
             rope=rope,
             cross_entropy=False,
@@ -1165,7 +1167,7 @@ def apply_liger_kernel_to_paligemma(
             model=language_model,
         )
 
-    elif isinstance(language_model, Gemma2ForCausalLM):
+    elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
         apply_liger_kernel_to_gemma2(
             rope=rope,
             cross_entropy=False,
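The two `isinstance` changes above widen the language-model dispatch: with `transformers==4.49.0` the dummy model's `language_model` is a `GemmaForCausalLM`, while with `transformers==4.53.1` it is a bare `GemmaModel` (see the captured stdout in Testing Done). Below is an illustrative sketch of that dispatch with a hypothetical helper name, using only the arguments visible in the hunks; it is not the actual layout of `apply_liger_kernel_to_paligemma`.

```python
# Hypothetical helper for illustration only; mirrors the widened isinstance checks above.
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaModel
from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM, Gemma2Model

from liger_kernel.transformers import apply_liger_kernel_to_gemma, apply_liger_kernel_to_gemma2


def patch_paligemma_language_model(language_model, rope: bool = True) -> None:
    # cross_entropy=False mirrors the call sites shown in the diff above.
    if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
        apply_liger_kernel_to_gemma(rope=rope, cross_entropy=False, model=language_model)
    elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
        apply_liger_kernel_to_gemma2(rope=rope, cross_entropy=False, model=language_model)
    else:
        # Error handling here is illustrative, not what monkey_patch.py does.
        raise TypeError(f"Unsupported PaliGemma language model: {type(language_model)}")
```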

test/transformers/test_monkey_patch.py

Lines changed: 67 additions & 0 deletions
@@ -38,6 +38,7 @@
     from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
     from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
+    from liger_kernel.transformers.model.paligemma import lce_forward as paligemma_lce_forward
     from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
     from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 else:
@@ -49,6 +50,7 @@
     )
     from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward
     from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward
+    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated as paligemma_lce_forward
     from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward
     from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward
 
@@ -126,6 +128,15 @@ def is_gemma3_available():
         return False
 
 
+def is_paligemma_available():
+    try:
+        import transformers.models.paligemma  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 def test_import_from_root():
     try:
         from liger_kernel.transformers import AutoLigerKernelForCausalLM  # noqa: F401
@@ -793,6 +804,62 @@ def test_apply_liger_kernel_to_instance_for_gemma2():
             pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
 
 
+@pytest.mark.skipif(not is_paligemma_available(), reason="paligemma module not available")
+def test_apply_liger_kernel_to_instance_for_paligemma():
+    # Ensure any monkey patching is cleaned up for subsequent tests
+    with patch("transformers.models.paligemma.modeling_paligemma"):
+        from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+
+        # Instantiate a dummy model
+        config = transformers.models.paligemma.configuration_paligemma.PaliGemmaConfig(
+            torch_dtype=torch.bfloat16,
+            text_config={
+                "num_hidden_layers": 2,
+                "rms_norm_eps": 1e-5,
+                "hidden_size": 32,
+                "intermediate_size": 64,
+                "hidden_act": "silu",
+            },
+            vision_config={
+                "num_hidden_layers": 2,
+                "layer_norm_eps": 1e-5,
+                "hidden_size": 48,
+                "intermediate_size": 64,
+            },
+        )
+
+        dummy_model_instance = PaliGemmaForConditionalGeneration(config)
+        assert isinstance(dummy_model_instance, PaliGemmaForConditionalGeneration)
+
+        # Check that model instance variables are not yet patched with Liger modules
+        assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(paligemma_lce_forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_tower.vision_model.post_layernorm.forward
+        ) != inspect.getsource(LigerLayerNorm.forward)
+
+        for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers:
+            assert inspect.getsource(layer.layer_norm1.forward) != inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(layer.layer_norm2.forward) != inspect.getsource(LigerLayerNorm.forward)
+
+        # Test applying kernels to the model instance
+        _apply_liger_kernel_to_instance(model=dummy_model_instance)
+
+        # Check that the model's instance variables were correctly patched with Liger modules
+        assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(paligemma_lce_forward)
+        assert inspect.getsource(
+            dummy_model_instance.vision_tower.vision_model.post_layernorm.forward
+        ) == inspect.getsource(LigerLayerNorm.forward)
+
+        for layer in dummy_model_instance.vision_tower.vision_model.encoder.layers:
+            assert inspect.getsource(layer.layer_norm1.forward) == inspect.getsource(LigerLayerNorm.forward)
+            assert inspect.getsource(layer.layer_norm2.forward) == inspect.getsource(LigerLayerNorm.forward)
+
+        try:
+            print(dummy_model_instance)
+        except Exception as e:
+            pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}")
+
+
 @pytest.mark.skipif(not is_gemma3_available(), reason="gemma3 module not available")
 def test_apply_liger_kernel_to_instance_for_gemma3_text():
     # Ensure any monkey patching is cleaned up for subsequent tests