Commit 0ea822f
Bug fixes in patching module (#834)
## Summary
1. Fix `_patch_layer_norm_module` by replacing `LigerRMSNorm` with
`LigerLayerNorm`.
2. Rename the instance rather than the class: replace patches like
`module.__class__.__name__ = LigerLayerNorm.__name__` with
`_bind_method_to_module(module, "_get_name", lambda self:
LigerLayerNorm.__name__)` (see the sketch after this list).
## Testing Done
```
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct").to(device)

# Patch the already-instantiated model in place, then inspect the repr.
apply_liger_kernel_to_qwen2(model=model)
print(model)
```
which prints the patched module tree, with the `Liger*` names reported by the renamed instances:
```
Applied Liger kernels to Qwen2
Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 896)
(layers): ModuleList(
(0-23): 24 x Qwen2DecoderLayer(
(self_attn): Qwen2Attention(
(q_proj): Linear(in_features=896, out_features=896, bias=True)
(k_proj): Linear(in_features=896, out_features=128, bias=True)
(v_proj): Linear(in_features=896, out_features=128, bias=True)
(o_proj): Linear(in_features=896, out_features=896, bias=False)
)
(mlp): LigerSwiGLUMLP(
(gate_proj): Linear(in_features=896, out_features=4864, bias=False)
(up_proj): Linear(in_features=896, out_features=4864, bias=False)
(down_proj): Linear(in_features=4864, out_features=896, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
(post_attention_layernorm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
)
)
(norm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
(rotary_emb): Qwen2RotaryEmbedding()
)
(lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
```
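
A corollary of fix 2 worth spot-checking: only the instance repr is renamed, while the module's class is untouched. A quick check along these lines (assuming the stock `Qwen2RMSNorm` class from transformers underneath):

```python
norm = model.model.norm
print(norm._get_name())     # LigerRMSNorm -- per-instance override from the patch
print(type(norm).__name__)  # Qwen2RMSNorm -- the class itself is left alone
```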
- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 file changed, +12 -12 lines