Skip to content

Commit 4ce469d

Browse files
authored
fix: upgrade liger to 0.5.8 and use native Gemma3 patches (axolotl-ai-cloud#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches
* fix: make lint happy
* doc: update Liger Kernel FLCE support for Gemma 3
1 parent 60a8f09 commit 4ce469d

File tree

3 files changed

+2
-36
lines changed

3 files changed

+2
-36
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ triton>=3.0.0
66
mamba-ssm==1.2.0.post1
77
xformers>=0.0.23.post1
88
autoawq==0.2.7.post3
9	- liger-kernel==0.5.6
9	+ liger-kernel==0.5.8
1010
# END section
1111

1212
packaging==23.2

src/axolotl/integrations/liger/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
2525
- deepseek_v2
2626
- gemma
2727
- gemma2
28	- - gemma3 (partial support, no support for FLCE yet)
28	+ - gemma3
2929
- granite
3030
- jamba
3131
- llama

src/axolotl/integrations/liger/__init__.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import inspect
2222
import logging
2323
import sys
24-
from functools import partial
2524

2625
from axolotl.integrations.base import BasePlugin
2726

@@ -55,7 +54,6 @@ def pre_model_load(self, cfg):
5554
)
5655
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
5756
from liger_kernel.transformers.functional import liger_cross_entropy
58-
from liger_kernel.transformers.geglu import LigerGEGLUMLP
5957
from liger_kernel.transformers.layer_norm import LigerLayerNorm
6058
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
6159
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ def pre_model_load(self, cfg):
141139
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
142140
if cfg.liger_fused_linear_cross_entropy:
143141
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
144-
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
145-
from transformers.models.gemma3 import modeling_gemma3
146-
147-
if cfg.liger_rope:
148-
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
149-
if cfg.liger_rms_norm:
150-
151-
def _liger_rms_norm_wrapper(dim, **kwargs):
152-
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
153-
return LigerRMSNorm(hidden_size=dim, **kwargs)
154-
155-
modeling_gemma3.Gemma3RMSNorm = partial(
156-
_liger_rms_norm_wrapper,
157-
offset=1.0,
158-
casting_mode="gemma",
159-
init_fn="zeros",
160-
in_place=False,
161-
)
162-
if cfg.liger_glu_activation:
163-
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
164-
if cfg.liger_layer_norm:
165-
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
166-
167-
if cfg.liger_cross_entropy:
168-
from transformers.loss.loss_utils import nn
169-
170-
nn.functional.cross_entropy = liger_cross_entropy
171-
172-
if cfg.liger_fused_linear_cross_entropy:
173-
raise NotImplementedError(
174-
"Fused linear cross entropy is not yet supported for Gemma3."
175-
)
176142
elif cfg.model_config_type == "llama4":
177143
from axolotl.integrations.liger.models.llama4 import (
178144
apply_liger_kernel_to_llama4,

0 commit comments

Comments
 (0)