Skip to content

Commit d2431a9

Browse files
Revert "Bug Fix: name patching for modules" (#833)
This reverts commit 67a5439. ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 67a5439 commit d2431a9

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
7878
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
7979
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
8080
_bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
81-
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerRMSNorm.__name__)
82-
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerRMSNorm.__name__)
81+
module.modules_to_save.default.__class__.__name__ = LigerRMSNorm.__name__
82+
module.original_module.__class__.__name__ = LigerRMSNorm.__name__
8383
else:
8484
module.offset = offset
8585
module.casting_mode = casting_mode
@@ -88,7 +88,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
8888
module.row_mode = row_mode
8989
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
9090
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
91-
_bind_method_to_module(module, "_get_name", lambda self: LigerRMSNorm.__name__)
91+
module.__class__.__name__ = LigerRMSNorm.__name__
9292

9393

9494
def _patch_layer_norm_module(module, eps=1e-6):
@@ -110,28 +110,28 @@ def _patch_layer_norm_module(module, eps=1e-6):
110110
module.original_module.hidden_size = getattr(module, "hidden_size", None) or getattr(
111111
module, "normalized_shape", None
112112
)
113-
_bind_method_to_module(module.modules_to_save.default, "forward", LigerLayerNorm.forward)
114-
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerLayerNorm.extra_repr)
115-
_bind_method_to_module(module.original_module, "forward", LigerLayerNorm.forward)
116-
_bind_method_to_module(module.original_module, "extra_repr", LigerLayerNorm.extra_repr)
117-
_bind_method_to_module(module.modules_to_save.default, "_get_name", lambda self: LigerLayerNorm.__name__)
118-
_bind_method_to_module(module.original_module, "_get_name", lambda self: LigerLayerNorm.__name__)
113+
_bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
114+
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
115+
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
116+
_bind_method_to_module(module.original_module, "extra_repr", LigerRMSNorm.extra_repr)
117+
module.modules_to_save.default.__class__.__name__ = LigerLayerNorm.__name__
118+
module.original_module.__class__.__name__ = LigerLayerNorm.__name__
119119
else:
120120
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
121121
module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
122122
_bind_method_to_module(module, "forward", LigerLayerNorm.forward)
123123
_bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
124-
_bind_method_to_module(module, "_get_name", lambda self: LigerLayerNorm.__name__)
124+
module.__class__.__name__ = LigerLayerNorm.__name__
125125

126126

127127
def _patch_swiglu_module(module, liger_module):
128128
_bind_method_to_module(module, "forward", liger_module.forward)
129-
_bind_method_to_module(module, "_get_name", lambda self: liger_module.__name__)
129+
module.__class__.__name__ = liger_module.__name__
130130

131131

132132
def _patch_geglu_module(module):
133133
_bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
134-
_bind_method_to_module(module, "_get_name", lambda self: LigerGEGLUMLP.__name__)
134+
module.__class__.__name__ = LigerGEGLUMLP.__name__
135135

136136

137137
def apply_liger_kernel_to_granite(

0 commit comments

Comments
 (0)