Commit 4a2da99

Fix RMS norm patching (#741)
## Summary

Fixes #739.

Creates new classes for the various RMS norms and removes the use of `partial` for RMS norms.

## Testing Done

Fixes errors of the form:

```
TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union
```

for glm4, olmo2, gemma1, and gemma2. However, we are now seeing errors when matching the actual logits with gemma models, which can be tracked separately in #729.

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 08f2ea4 commit 4a2da99
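
For context, here is a minimal, self-contained sketch of the failure mode described in the commit message; the class names are illustrative stand-ins, not code from this repository. Binding constructor defaults with `functools.partial` produces a `partial` object rather than a class, so any later `isinstance()` check against the patched attribute raises the quoted `TypeError`, whereas a real subclass keeps such checks valid.

```python
from functools import partial


class RMSNorm:
    """Stand-in for a model's RMSNorm class (illustrative only)."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
        self.hidden_size, self.eps, self.offset = hidden_size, eps, offset


# Old approach: patch by binding kwargs with functools.partial.
PatchedRMSNorm = partial(RMSNorm, offset=1.0)
norm = PatchedRMSNorm(1024)  # construction still works...

try:
    isinstance(norm, PatchedRMSNorm)  # ...but a partial object is not a type
except TypeError as err:
    print(err)  # isinstance() arg 2 must be a type, a tuple of types, or a union


# New approach: a real subclass that pins the same defaults.
class RMSNormWithOffset(RMSNorm):
    def __init__(self, hidden_size, eps=1e-6, offset=1.0):
        super().__init__(hidden_size, eps, offset)


print(isinstance(RMSNormWithOffset(1024), RMSNorm))  # True: isinstance works again
```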

File tree

3 files changed (+44, −14 lines)


src/liger_kernel/transformers/gema3_rms.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 9 additions & 6 deletions
@@ -627,8 +627,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel

-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

     if rope:
@@ -701,7 +701,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -780,8 +781,8 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
     from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel

-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3

     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -1451,11 +1452,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model

     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2

     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1514,11 +1516,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model

     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
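
The hunks above all follow the same pattern: a module-level RMSNorm attribute is reassigned to a real class instead of a `partial`. The toy sketch below (hypothetical module and class names, not the actual transformers or liger_kernel code) shows how that class-for-class replacement behaves: models built after the patch construct the replacement class, and the attribute remains a valid target for `isinstance()` checks.

```python
# Toy sketch of the class-for-class monkey patching used above; module and
# class names are hypothetical, not the real transformers / liger_kernel code.
import types

# Pretend this is transformers' modeling module for some architecture.
modeling_toy = types.ModuleType("modeling_toy")


class ToyRMSNorm:
    def __init__(self, hidden_size, eps=1e-6):
        self.hidden_size, self.eps = hidden_size, eps


modeling_toy.ToyRMSNorm = ToyRMSNorm


class LigerToyRMSNorm:
    """A real class (not a partial) with the kernel's defaults baked in."""

    def __init__(self, hidden_size, eps=1e-6, in_place=False):
        self.hidden_size, self.eps, self.in_place = hidden_size, eps, in_place


# The patch: reassign the module attribute, so code that looks up
# modeling_toy.ToyRMSNorm afterwards constructs the Liger class instead.
modeling_toy.ToyRMSNorm = LigerToyRMSNorm

norm = modeling_toy.ToyRMSNorm(1024)
print(type(norm).__name__)                        # LigerToyRMSNorm
print(isinstance(norm, modeling_toy.ToyRMSNorm))  # True: the attribute is still a type
```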

src/liger_kernel/transformers/rms_norm.py

Lines changed: 35 additions & 0 deletions
@@ -44,3 +44,38 @@ def extra_repr(self):
         return (
             f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
         )
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
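
A short usage sketch of the new per-model wrappers, assuming `liger_kernel` is installed; the hidden sizes are illustrative, the `extra_repr()` output is indicative only, and the forward pass is guarded because the underlying Triton kernel expects a GPU.

```python
import torch

from liger_kernel.transformers.rms_norm import (
    LigerRMSNorm,
    LigerRMSNormForGemma2,
    LigerRMSNormForOlmo2,
)

# Per-model defaults are now baked into real subclasses instead of partials.
gemma2_norm = LigerRMSNormForGemma2(2304)  # offset=1.0, casting_mode="gemma", zero-initialized weight
olmo2_norm = LigerRMSNormForOlmo2(4096)    # offset=0.0, casting_mode="llama", ones-initialized weight

# Being plain classes, they are valid targets for isinstance() checks.
print(isinstance(gemma2_norm, LigerRMSNorm), isinstance(olmo2_norm, LigerRMSNorm))  # True True
print(gemma2_norm.extra_repr())  # e.g. "(2304,), eps=1e-06, offset=1.0, in_place=False"

# The RMS norm kernel runs on the GPU; only exercise the forward pass if one is available.
if torch.cuda.is_available():
    x = torch.randn(2, 8, 2304, device="cuda")
    print(gemma2_norm.to("cuda")(x).shape)  # torch.Size([2, 8, 2304])
```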
