|
21 | 21 | import inspect
|
22 | 22 | import logging
|
23 | 23 | import sys
|
24 |
| -from functools import partial |
25 | 24 |
|
26 | 25 | from axolotl.integrations.base import BasePlugin
|
27 | 26 |
|
@@ -55,7 +54,6 @@ def pre_model_load(self, cfg):
|
55 | 54 | )
|
56 | 55 | from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
57 | 56 | from liger_kernel.transformers.functional import liger_cross_entropy
|
58 |
| - from liger_kernel.transformers.geglu import LigerGEGLUMLP |
59 | 57 | from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
60 | 58 | from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
61 | 59 | from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
@@ -141,38 +139,6 @@ def pre_model_load(self, cfg):
|
141 | 139 | modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
142 | 140 | if cfg.liger_fused_linear_cross_entropy:
|
143 | 141 | modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
144 |
| - elif cfg.model_config_type in ["gemma3", "gemma3_text"]: |
145 |
| - from transformers.models.gemma3 import modeling_gemma3 |
146 |
| - |
147 |
| - if cfg.liger_rope: |
148 |
| - modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb |
149 |
| - if cfg.liger_rms_norm: |
150 |
| - |
151 |
| - def _liger_rms_norm_wrapper(dim, **kwargs): |
152 |
| - "Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm" |
153 |
| - return LigerRMSNorm(hidden_size=dim, **kwargs) |
154 |
| - |
155 |
| - modeling_gemma3.Gemma3RMSNorm = partial( |
156 |
| - _liger_rms_norm_wrapper, |
157 |
| - offset=1.0, |
158 |
| - casting_mode="gemma", |
159 |
| - init_fn="zeros", |
160 |
| - in_place=False, |
161 |
| - ) |
162 |
| - if cfg.liger_glu_activation: |
163 |
| - modeling_gemma3.Gemma3MLP = LigerGEGLUMLP |
164 |
| - if cfg.liger_layer_norm: |
165 |
| - modeling_gemma3.nn.LayerNorm = LigerLayerNorm |
166 |
| - |
167 |
| - if cfg.liger_cross_entropy: |
168 |
| - from transformers.loss.loss_utils import nn |
169 |
| - |
170 |
| - nn.functional.cross_entropy = liger_cross_entropy |
171 |
| - |
172 |
| - if cfg.liger_fused_linear_cross_entropy: |
173 |
| - raise NotImplementedError( |
174 |
| - "Fused linear cross entropy is not yet supported for Gemma3." |
175 |
| - ) |
176 | 142 | elif cfg.model_config_type == "llama4":
|
177 | 143 | from axolotl.integrations.liger.models.llama4 import (
|
178 | 144 | apply_liger_kernel_to_llama4,
|
|
0 commit comments