Commit fdc179b

Fix incorrect module name when monkey_patch applied to instantiated model
1 parent 812b050 commit fdc179b

1 file changed: +24 −12 lines changed

1 file changed

+24
-12
lines changed

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 24 additions & 12 deletions
@@ -52,13 +52,25 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__


 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__


 def apply_liger_kernel_to_granite(
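The one-line `module.__class__.__name__ = ...` additions above are what make an already-instantiated module print as its Liger counterpart: `nn.Module.__repr__` reports the class name, which `_bind_method_to_module` alone never touched. Below is a minimal sketch of the mechanism; the toy `OrigNorm`/`FastNorm` classes and the standalone `_bind_method_to_module` helper are illustrative stand-ins, not the library's exact code.

```python
import types

import torch
import torch.nn as nn


def _bind_method_to_module(module, method_name, new_method):
    # Bind the plain function to this instance so `self` resolves to `module`.
    setattr(module, method_name, types.MethodType(new_method, module))


class OrigNorm(nn.Module):  # stands in for e.g. LlamaRMSNorm
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        return x  # original implementation, about to be replaced


class FastNorm(nn.Module):  # stands in for LigerRMSNorm
    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(var + self.variance_epsilon)

    def extra_repr(self):
        return f"eps={self.variance_epsilon}"


m = OrigNorm(8)
_bind_method_to_module(m, "forward", FastNorm.forward)
_bind_method_to_module(m, "extra_repr", FastNorm.extra_repr)
print(m)  # without the fix: still prints "OrigNorm(eps=1e-06)"

m.__class__.__name__ = FastNorm.__name__  # the fix
print(m)  # now prints "FastNorm(eps=1e-06)"
```

Note that `__name__` is rewritten on the instance's class object, so every module of that original class picks up the new name; that matches the patch-everything-in-place usage in the loops below.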
@@ -134,7 +146,7 @@ def apply_liger_kernel_to_granite(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,7 +218,7 @@ def apply_liger_kernel_to_llama(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -296,7 +308,7 @@ def apply_liger_kernel_to_mllama(
             _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +382,7 @@ def apply_liger_kernel_to_mistral(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +454,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +528,7 @@ def apply_liger_kernel_to_gemma(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +604,7 @@ def apply_liger_kernel_to_gemma2(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -776,7 +788,7 @@ def apply_liger_kernel_to_qwen2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -849,7 +861,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -916,7 +928,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -985,7 +997,7 @@ def apply_liger_kernel_to_phi3(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1048,7 +1060,7 @@ def apply_liger_kernel_to_olmo2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
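Taken together, the per-model loops now route every SwiGLU/GEGLU patch through the new helpers, so the visible effect of this commit is in the module names of a model that was patched after instantiation. A hedged usage sketch follows; the checkpoint name and the exact attribute path are illustrative, not prescribed by this commit.

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch an already-instantiated model instead of patching the
# transformers classes before construction.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
apply_liger_kernel_to_llama(model=model)

layer = model.model.layers[0]
print(type(layer.mlp).__name__)              # expected: LigerSwiGLUMLP
print(type(layer.input_layernorm).__name__)  # expected: LigerRMSNorm
```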
