Commit 5e3bf99
apply monkey patch to instance instead of class method to avoid conflicts in mixed usage scenario (#772)
In some cases, **revert_liger_kernel_to_XX** cannot undo the replacement of
**model.forward** performed by **monkey_patch**. For scenarios where a model
instance is passed in, changing how **XX_lce_forward** is assigned prevents
this issue.
Taking **llama** as an example, **monkey_patch.py** is updated so that the
**revert_liger_kernel_to_llama** logic can restore all module configurations
in the environment.
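For context, here is a minimal sketch of what that reload-based restoration looks like (the helper below is illustrative; the actual revert helpers in the test suite may differ):
```
import importlib

from transformers.models.llama import modeling_llama


def revert_liger_kernel_to_llama_sketch():
    # Reloading the module restores the original class-level attributes
    # (LlamaRMSNorm, LlamaMLP, LlamaForCausalLM.forward, ...). As long as
    # the class method itself was never overwritten, this is enough to
    # undo the Liger patch.
    importlib.reload(modeling_llama)
```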
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
When I use **pytest** to run the **test_sft_trainer_with_liger** and
**test_train_offloading** functions from
[test_sft_slow.py](https://github.com/huggingface/trl/blob/main/tests/slow/test_sft_slow.py)
in the same process, the monkey patch applied by the Liger kernel in the
first test affects the test that runs afterwards.
Even though I used the **revert_liger_kernel_to_XX** logic to restore the
environment's module configuration, I found that **model.forward** still
retains a reference to **XX_lce_forward**.
When the model passed to **apply_liger_kernel_to_XX** is an instance, using
**MethodType** to patch only the forward method of that instance avoids this
issue.
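As a minimal sketch of that idea (the helper name and signature below are illustrative, not the exact code in monkey_patch.py):
```
from types import MethodType


def _patch_lce_forward(lce_forward, model=None, model_cls=None):
    # Illustrative helper: when a model instance is passed in, bind the
    # patched forward to that instance only, leaving the class attribute
    # untouched so revert_liger_kernel_to_XX can fully restore the
    # environment's module configuration.
    if model is not None:
        model.forward = MethodType(lce_forward, model)
    else:
        # No instance given: patch the class method as before.
        model_cls.forward = lce_forward
```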
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->
I wrote a simple test script to reproduce the bug:
```
import importlib

from transformers import AutoModelForCausalLM
from transformers.utils import is_liger_kernel_available
from transformers.models.llama import modeling_llama

if __name__ == '__main__':
    model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'
    model_init_kwargs = {}
    print(f"ori_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, ori_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance

        _apply_liger_kernel_to_instance(model=model)
    print(f"liger_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, liger_forward:{modeling_llama.LlamaForCausalLM.forward}")
    print(f"liger_model:{model.forward}")
    importlib.reload(modeling_llama)
    print(f"reload_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, reload_forward:{modeling_llama.LlamaForCausalLM.forward}")
    model_reload = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
    print(f"reload_model:{model_reload.forward}")
```
The above code produces the following output:
```
ori_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, ori_forward:<function LlamaForCausalLM.forward at 0x7f7f3321eca0>
liger_LlamaRMSNorm:<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>, liger_forward:<function lce_forward at 0x7f7e26382160>
liger_model:<bound method lce_forward of LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 8)
(layers): ModuleList(
(0-1): 2 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8, out_features=8, bias=False)
(k_proj): Linear(in_features=8, out_features=4, bias=False)
(v_proj): Linear(in_features=8, out_features=4, bias=False)
(o_proj): Linear(in_features=8, out_features=8, bias=False)
)
(mlp): LigerSwiGLUMLP(
(gate_proj): Linear(in_features=8, out_features=32, bias=False)
(up_proj): Linear(in_features=8, out_features=32, bias=False)
(down_proj): Linear(in_features=32, out_features=8, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
(post_attention_layernorm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
)
)
(norm): LigerRMSNorm((8,), eps=1e-06, offset=0.0, in_place=True)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
reload_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, reload_forward:<function LlamaForCausalLM.forward at 0x7f7e263be700>
reload_model:<bound method lce_forward of LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 8)
(layers): ModuleList(
(0-1): 2 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8, out_features=8, bias=False)
(k_proj): Linear(in_features=8, out_features=4, bias=False)
(v_proj): Linear(in_features=8, out_features=4, bias=False)
(o_proj): Linear(in_features=8, out_features=8, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=8, out_features=32, bias=False)
(up_proj): Linear(in_features=8, out_features=32, bias=False)
(down_proj): Linear(in_features=32, out_features=8, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((8,), eps=1e-06)
(post_attention_layernorm): LlamaRMSNorm((8,), eps=1e-06)
)
)
(norm): LlamaRMSNorm((8,), eps=1e-06)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=8, out_features=128256, bias=False)
)>
```
Clearly, `reload_model: <bound method lce_forward of LlamaForCausalLM(` should
not be `lce_forward` here: even after reloading `modeling_llama`, a freshly
loaded model is still bound to the patched forward.
## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
After modifying monkey_patch.py, the script correctly prints
`reload_model: <bound method LlamaForCausalLM.forward of LlamaForCausalLM(`.
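For reference, the mixed-usage scenario from the Summary can be exercised in a single process with something like the following (the pytest invocation is illustrative and assumes a TRL checkout):
```
import pytest

# Run both slow TRL tests in one process; before this change, the Liger
# patch applied by the first test leaked into the second one.
pytest.main([
    "tests/slow/test_sft_slow.py",
    "-k", "test_sft_trainer_with_liger or test_train_offloading",
])
```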
<!--
Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
-->
XPU or CUDA

File tree (2 files changed: +202 −32)
- src/liger_kernel/transformers/monkey_patch.py (+113 −31, mode 100644 → 100755)
- test/transformers