Commit e808281
Fix paligemma missing test_liger_kernel_to_instance_for_paligemma_instance test in test_monkey_patch.py (#785)
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
What this PR adds:
- A new unit test, `test_apply_liger_kernel_to_instance_for_paligemma()`, in `test/transformers/test_monkey_patch.py`. Fixes #776.
- The test instantiates a dummy `PaliGemmaForConditionalGeneration` model, confirms it is un-patched, runs `_apply_liger_kernel_to_instance()`, and then verifies that:
  - `model.forward` is replaced by `paligemma_lce_forward()`.
  - `vision_tower.vision_model.post_layernorm.forward` is replaced by `LigerLayerNorm.forward`.
  - Every encoder layer's `layer_norm1.forward` and `layer_norm2.forward` are also replaced.
- Source equality is checked with `inspect.getsource` before and after patching (a rough sketch of the test's shape is shown below).
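
For reviewers who want the shape of the new test without opening the diff, here is a minimal sketch. The tiny config values and the `liger_kernel` import paths are illustrative assumptions, and attribute paths for the vision tower can differ between transformers versions (compare the two captured model prints below); the real test in `test/transformers/test_monkey_patch.py` also compares `model.forward` against `paligemma_lce_forward` directly.

```python
# Minimal sketch of the test's shape (not the verbatim file).
# Config sizes and import paths below are illustrative assumptions.
import inspect

from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration

from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance


def test_apply_liger_kernel_to_instance_for_paligemma():
    # Build a tiny dummy model so the test runs quickly on CPU (sizes are illustrative).
    config = PaliGemmaConfig(
        vision_config={"hidden_size": 48, "intermediate_size": 64, "num_hidden_layers": 2, "num_attention_heads": 4},
        text_config={"hidden_size": 32, "intermediate_size": 64, "num_hidden_layers": 2},
    )
    model = PaliGemmaForConditionalGeneration(config)
    # On newer transformers the tower is nested under model.model; this accessor is assumed to still resolve.
    vision_model = model.vision_tower.vision_model

    liger_ln_src = inspect.getsource(LigerLayerNorm.forward)
    original_forward_src = inspect.getsource(model.forward)

    # The dummy instance must start out un-patched.
    assert inspect.getsource(vision_model.post_layernorm.forward) != liger_ln_src

    _apply_liger_kernel_to_instance(model=model)

    # After patching, the conditional-generation forward and every layer norm
    # in the vision tower should point at the Liger implementations.
    assert inspect.getsource(model.forward) != original_forward_src
    assert inspect.getsource(vision_model.post_layernorm.forward) == liger_ln_src
    for layer in vision_model.encoder.layers:
        assert inspect.getsource(layer.layer_norm1.forward) == liger_ln_src
        assert inspect.getsource(layer.layer_norm2.forward) == liger_ln_src
```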
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
`transformers==4.49.0`
<details>
<summary>Test result</summary>

```
❯ python3 -m pytest test/transformers/test_monkey_patch.py -k paligemma -v -rP
============================================== test session starts ==============================================
platform linux -- Python 3.11.11, pytest-8.4.1, pluggy-1.6.0 -- /home/vvvdwbvvv/.local/bin/python3
cachedir: .pytest_cache
rootdir: /home/vvvdwbvvv/develop/Liger-Kernel
configfile: pyproject.toml
plugins: asyncio-1.0.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 33 items / 32 deselected / 1 selected

test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_paligemma
------------------------------------------------- live log call -------------------------------------------------
INFO     liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
PASSED                                                                                                     [100%]
==================================================== PASSES =====================================================
_______________________________ test_apply_liger_kernel_to_instance_for_paligemma _______________________________
--------------------------------------------- Captured stdout call ----------------------------------------------
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 48, kernel_size=(16, 16), stride=(16, 16), padding=valid)
        (position_embedding): Embedding(196, 48)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-1): 2 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=48, out_features=48, bias=True)
              (v_proj): Linear(in_features=48, out_features=48, bias=True)
              (q_proj): Linear(in_features=48, out_features=48, bias=True)
              (out_proj): Linear(in_features=48, out_features=48, bias=True)
            )
            (layer_norm1): LigerLayerNorm((48,), eps=1e-05)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=48, out_features=64, bias=True)
              (fc2): Linear(in_features=64, out_features=48, bias=True)
            )
            (layer_norm2): LigerLayerNorm((48,), eps=1e-05)
          )
        )
      )
      (post_layernorm): LigerLayerNorm((48,), eps=1e-05)
      (head): SiglipMultiheadAttentionPoolingHead(
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=48, out_features=48, bias=True)
        )
        (layernorm): LigerLayerNorm((48,), eps=1e-05, elementwise_affine=True)
        (mlp): SiglipMLP(
          (activation_fn): PytorchGELUTanh()
          (fc1): Linear(in_features=48, out_features=64, bias=True)
          (fc2): Linear(in_features=64, out_features=48, bias=True)
        )
      )
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=48, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 32, padding_idx=0)
      (layers): ModuleList(
        (0-1): 2 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=32, out_features=4096, bias=False)
            (k_proj): Linear(in_features=32, out_features=4096, bias=False)
            (v_proj): Linear(in_features=32, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=32, bias=False)
          )
          (mlp): LigerGEGLUMLP(
            (gate_proj): Linear(in_features=32, out_features=64, bias=False)
            (up_proj): Linear(in_features=32, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=32, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
          (post_attention_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
        )
      )
      (norm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
      (rotary_emb): GemmaRotaryEmbedding()
    )
    (lm_head): Linear(in_features=32, out_features=256000, bias=False)
  )
)
----------------------------------------------- Captured log call -----------------------------------------------
INFO     liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
======================================= 1 passed, 32 deselected in 1.78s ========================================
```
</details>
`transformers==4.53.1`
<details>
<summary>Test result</summary>

```
❯ python3 -m pytest test/transformers/test_monkey_patch.py -k paligemma -v -rP
============================================== test session starts ==============================================
platform linux -- Python 3.11.11, pytest-8.4.1, pluggy-1.6.0 -- /home/vvvdwbvvv/.local/bin/python3
cachedir: .pytest_cache
rootdir: /home/vvvdwbvvv/develop/Liger-Kernel
configfile: pyproject.toml
plugins: asyncio-1.0.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 33 items / 32 deselected / 1 selected

test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_paligemma
------------------------------------------------- live log call -------------------------------------------------
INFO     liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
PASSED                                                                                                     [100%]
==================================================== PASSES =====================================================
_______________________________ test_apply_liger_kernel_to_instance_for_paligemma _______________________________
--------------------------------------------- Captured stdout call ----------------------------------------------
PaliGemmaForConditionalGeneration(
  (model): PaliGemmaModel(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 48, kernel_size=(16, 16), stride=(16, 16), padding=valid)
          (position_embedding): Embedding(196, 48)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-1): 2 x SiglipEncoderLayer(
              (layer_norm1): LigerLayerNorm((48,), eps=1e-05)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=48, out_features=48, bias=True)
                (v_proj): Linear(in_features=48, out_features=48, bias=True)
                (q_proj): Linear(in_features=48, out_features=48, bias=True)
                (out_proj): Linear(in_features=48, out_features=48, bias=True)
              )
              (layer_norm2): LigerLayerNorm((48,), eps=1e-05)
              (mlp): SiglipMLP(
                (activation_fn): PytorchGELUTanh()
                (fc1): Linear(in_features=48, out_features=64, bias=True)
                (fc2): Linear(in_features=64, out_features=48, bias=True)
              )
            )
          )
        )
        (post_layernorm): LigerLayerNorm((48,), eps=1e-05)
        (head): SiglipMultiheadAttentionPoolingHead(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=48, out_features=48, bias=True)
          )
          (layernorm): LigerLayerNorm((48,), eps=1e-05, elementwise_affine=True)
          (mlp): SiglipMLP(
            (activation_fn): PytorchGELUTanh()
            (fc1): Linear(in_features=48, out_features=64, bias=True)
            (fc2): Linear(in_features=64, out_features=48, bias=True)
          )
        )
      )
    )
    (multi_modal_projector): PaliGemmaMultiModalProjector(
      (linear): Linear(in_features=48, out_features=2048, bias=True)
    )
    (language_model): GemmaModel(
      (embed_tokens): Embedding(256000, 32, padding_idx=0)
      (layers): ModuleList(
        (0-1): 2 x GemmaDecoderLayer(
          (self_attn): GemmaAttention(
            (q_proj): Linear(in_features=32, out_features=4096, bias=False)
            (k_proj): Linear(in_features=32, out_features=4096, bias=False)
            (v_proj): Linear(in_features=32, out_features=4096, bias=False)
            (o_proj): Linear(in_features=4096, out_features=32, bias=False)
          )
          (mlp): LigerGEGLUMLP(
            (gate_proj): Linear(in_features=32, out_features=64, bias=False)
            (up_proj): Linear(in_features=32, out_features=64, bias=False)
            (down_proj): Linear(in_features=64, out_features=32, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
          (post_attention_layernorm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
        )
      )
      (norm): LigerRMSNorm((32,), eps=1e-05, offset=1.0, in_place=True, row_mode=None)
      (rotary_emb): GemmaRotaryEmbedding()
    )
  )
  (lm_head): Linear(in_features=32, out_features=256000, bias=False)
)
----------------------------------------------- Captured log call -----------------------------------------------
INFO     liger_kernel.transformers.monkey_patch:monkey_patch.py:1864 Applying Liger kernels to model instance with model type: paligemma with kwargs: {}
======================================= 1 passed, 32 deselected in 2.42s ========================================
```
</details>
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Tcc0403 <[email protected]>
File tree: 2 files changed (+71, -2 lines)
- src/liger_kernel/transformers
- test/transformers