@@ -52,13 +52,25 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__


 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__


 def apply_liger_kernel_to_granite(
@@ -134,7 +146,7 @@ def apply_liger_kernel_to_granite(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,7 +218,7 @@ def apply_liger_kernel_to_llama(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -296,7 +308,7 @@ def apply_liger_kernel_to_mllama(
             _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +382,7 @@ def apply_liger_kernel_to_mistral(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +454,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +528,7 @@ def apply_liger_kernel_to_gemma(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +604,7 @@ def apply_liger_kernel_to_gemma2(

         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -776,7 +788,7 @@ def apply_liger_kernel_to_qwen2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -849,7 +861,7 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -916,7 +928,7 @@ def apply_liger_kernel_to_qwen2_5_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -985,7 +997,7 @@ def apply_liger_kernel_to_phi3(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1048,7 +1060,7 @@ def apply_liger_kernel_to_olmo2(

         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
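
A minimal usage sketch (not part of this commit) of what the new _patch_swiglu_module / _patch_geglu_module helpers and the __class__.__name__ assignments buy when patching an already-instantiated model: the patched submodules report the Liger class names in the module tree. The checkpoint id and keyword arguments below are illustrative assumptions, not taken from this diff.

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # Load a LLaMA-style model first, then patch the live instance in place.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    apply_liger_kernel_to_llama(rms_norm=True, swiglu=True, model=model)

    # Because the helpers rebind forward() and overwrite __class__.__name__,
    # the patched modules now identify themselves as Liger modules:
    print(type(model.model.layers[0].mlp).__name__)              # LigerSwiGLUMLP
    print(type(model.model.layers[0].input_layernorm).__name__)  # LigerRMSNorm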