Skip to content

Commit 0701bab

Browse files
authored
Fix: row_mode patching in _patch_rms_norm_module (#765)
## Summary There is a bug in #731 that only occurs when the model is not initialized with `AutoLigerKernelForCausalLM`: in that case `self.row_mode` is never set on the patched module. Fixes #764. ## Details Adding `row_mode=None` as a default parameter of `_patch_rms_norm_module`, and assigning it to the patched module, fixes the issue. `def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):` ## Testing Done - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 766ca10 commit 0701bab

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
5454
module.__dict__[method_name] = new_method.__get__(module, module.__class__)
5555

5656

57-
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
57+
def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
5858
# Check if the module is a PEFT ModulesToSaveWrapper
5959
# If it is, we need to patch the modules_to_save.default and original_modules
6060
if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,12 +64,14 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
6464
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
6565
)
6666
module.modules_to_save.default.in_place = in_place
67+
module.modules_to_save.default.row_mode = row_mode
6768
module.original_module.offset = offset
6869
module.original_module.casting_mode = casting_mode
6970
module.original_module.variance_epsilon = (
7071
getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
7172
)
7273
module.original_module.in_place = in_place
74+
module.original_module.row_mode = row_mode
7375
_bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
7476
_bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
7577
_bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
@@ -81,6 +83,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
8183
module.casting_mode = casting_mode
8284
module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
8385
module.in_place = in_place
86+
module.row_mode = row_mode
8487
_bind_method_to_module(module, "forward", LigerRMSNorm.forward)
8588
_bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
8689
module.__class__.__name__ = LigerRMSNorm.__name__

src/liger_kernel/transformers/rms_norm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def forward(self, hidden_states):
4141
)
4242

4343
def extra_repr(self):
44-
return (
45-
f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
46-
)
44+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
4745

4846

4947
class LigerRMSNormForGemma(LigerRMSNorm):

0 commit comments

Comments
 (0)