Commit 0ea822f

Bug fixes in patching module (#834)
## Summary

1. Fix `_patch_layer_norm_module` by replacing `LigerRMSNorm` with `LigerLayerNorm`.
2. Rename the instance rather than the class by replacing patches such as `module.__class__.__name__ = LigerLayerNorm.__name__` with `_bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)`.

## Testing Done

```
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct").to(device)
apply_liger_kernel_to_qwen2(model=model)
print(model)
```

prints:

```
Applied Liger kernels to Qwen2
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): LigerSwiGLUMLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
        (post_attention_layernorm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
      )
    )
    (norm): LigerRMSNorm((896,), eps=1e-06, offset=0.0, in_place=True, row_mode=None)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=896, out_features=151936, bias=False)
)
```

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
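For context on item 2 of the Summary: `nn.Module.__repr__` labels each module via `self._get_name()`, which defaults to `type(self).__name__`, so mutating `module.__class__.__name__` renames the class itself and with it every instance, patched or not, whereas binding `_get_name` on the instance relabels only that module. Below is a minimal sketch of the difference, using a local stand-in for the repo's `_bind_method_to_module` helper; the names and printed reprs are illustrative, not the exact code in `monkey_patch.py`.

```
import types

import torch.nn as nn


def _bind_method_to_module(module, method_name, new_method):
    # Attach `new_method` as a bound method on this instance only;
    # other instances of the same class keep their original behavior.
    setattr(module, method_name, types.MethodType(new_method, module))


norm_a = nn.LayerNorm(8)
norm_b = nn.LayerNorm(8)

# Old (buggy) approach: renaming the class would relabel *every* LayerNorm,
# including modules that were never patched:
#   norm_a.__class__.__name__ = "LigerLayerNorm"

# Fixed approach: override _get_name (used by nn.Module.__repr__) per instance.
_bind_method_to_module(norm_a, "_get_name", lambda self: "LigerLayerNorm")

print(norm_a)  # LigerLayerNorm((8,), eps=1e-05, ...)
print(norm_b)  # LayerNorm((8,), eps=1e-05, ...)  -- unaffected
```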
1 parent d2431a9 commit 0ea822f

File tree

1 file changed: +12 −12 lines


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 12 additions & 12 deletions
```
@@ -78,8 +78,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
-        module.original_module.__class__.__name__ = LigerRMSNorm.__name__
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
     else:
         module.offset = offset
         module.casting_mode = casting_mode
@@ -88,7 +88,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.__class__.__name__ = LigerRMSNorm.__name__
+        _bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)


 def _patch_layer_norm_module(module, eps=1e-6):
@@ -110,28 +110,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
         module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
             module, "normalized_shape", None
         )
-        _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
-        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
-        _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
-        _bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
-        module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
-        module.original_module.__class__.__name__ = LigerLayerNorm.__name__
+        _bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
+        _bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
+        _bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
+        _bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
     else:
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
         _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
-        module.__class__.__name__ = LigerLayerNorm.__name__
+        _bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)


 def _patch_swiglu_module(module, liger_module):
     _bind_method_to_module(module, "forward", liger_module.forward)
-    module.__class__.__name__ = liger_module.__name__
+    _bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)


 def _patch_geglu_module(module):
     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
-    module.__class__.__name__ = LigerGEGLUMLP.__name__
+    _bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)


 def apply_liger_kernel_to_granite(
```
