
Commit 0ba7d36

qwen3 and qwen3_moe support for liger kernels (axolotl-ai-cloud#2612)
* qwen3 and qwen3_moe support for liger kernels
* fix moe module path
* fix: qwen3 liger input args and mlp
* fix: qwen3 input args and output class

Co-authored-by: NanoCode012 <[email protected]>
1 parent e4f73bc commit 0ba7d36

File tree

3 files changed: +375 additions, -0 deletions

src/axolotl/integrations/liger/__init__.py
src/axolotl/integrations/liger/models/qwen3.py
src/axolotl/integrations/liger/models/qwen3_moe.py

src/axolotl/integrations/liger/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -151,6 +151,30 @@ def pre_model_load(self, cfg):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
+        elif cfg.model_config_type == "qwen3":
+            from axolotl.integrations.liger.models.qwen3 import (
+                apply_liger_kernel_to_qwen3,
+            )
+
+            apply_liger_kernel_to_qwen3(
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                glu_activation=cfg.liger_glu_activation,
+                rms_norm=cfg.liger_rms_norm,
+                layer_norm=cfg.liger_layer_norm,
+            )
+        elif cfg.model_config_type == "qwen3_moe":
+            from axolotl.integrations.liger.models.qwen3_moe import (
+                apply_liger_kernel_to_qwen3_moe,
+            )
+
+            apply_liger_kernel_to_qwen3_moe(
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                glu_activation=cfg.liger_glu_activation,
+                rms_norm=cfg.liger_rms_norm,
+                layer_norm=cfg.liger_layer_norm,
+            )
         else:
             logging.warning(
                 f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
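
The new branches follow the same pattern as the existing model types: when cfg.model_config_type is "qwen3" or "qwen3_moe", the corresponding patch function is imported and called with the liger_* flags from the axolotl config. For reference, the sketch below is not part of the commit; it applies the same patch by hand, which is effectively what the pre_model_load hook does before the model is loaded. The checkpoint name is a placeholder, and liger-kernel plus a Qwen3-capable transformers release are assumed to be installed.

# Illustrative sketch only: applying the Qwen3 Liger patch manually, mirroring
# what the pre_model_load hook above does. Patching must happen before the
# model is instantiated so the swapped classes and forward are picked up.
from transformers import AutoModelForCausalLM

from axolotl.integrations.liger.models.qwen3 import apply_liger_kernel_to_qwen3

apply_liger_kernel_to_qwen3(
    fused_linear_cross_entropy=True,  # maps to cfg.liger_fused_linear_cross_entropy
    glu_activation=True,  # maps to cfg.liger_glu_activation
    rms_norm=True,  # maps to cfg.liger_rms_norm
)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # placeholder checkpoint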
src/axolotl/integrations/liger/models/qwen3.py

Lines changed: 160 additions & 0 deletions

@@ -0,0 +1,160 @@
"""
Liger FLCE for Qwen3. Based on transformers v4.51.3.
"""

import sys
from typing import Optional, Tuple, Union

import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them
            only for that token can save memory, which becomes pretty significant for long sequences or large
            vocabulary sizes. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the
            sequence length dimension. This is useful when using packed tensor format (single dimension for
            batch and sequence length).

    Returns:
    """
    # pylint: disable=duplicate-code
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:  # if in inference mode, materialize logits
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def apply_liger_kernel_to_qwen3(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> None:
    # pylint: disable=duplicate-code
    """
    Apply Liger kernels to replace the original implementations in HuggingFace Qwen3 models.

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory
            efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
    """
    import transformers.models.qwen3.modeling_qwen3  # noqa: F401 # pylint: disable=unused-import
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    modeling_qwen3 = sys.modules["transformers.models.qwen3.modeling_qwen3"]

    if rms_norm:
        modeling_qwen3.Qwen3RMSNorm = LigerRMSNorm

    if glu_activation:
        modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP

    if layer_norm:
        modeling_qwen3.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_qwen3.Qwen3ForCausalLM.forward = lce_forward
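
apply_liger_kernel_to_qwen3 works by swapping module-level symbols in transformers.models.qwen3.modeling_qwen3 and rebinding Qwen3ForCausalLM.forward, so it has to run before the model is constructed. A small illustrative check, not part of the commit, that the swap took effect:

# Illustrative check, assuming liger-kernel and a transformers release with
# Qwen3 support (the file above targets v4.51.3) are installed.
import transformers.models.qwen3.modeling_qwen3 as modeling_qwen3

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

from axolotl.integrations.liger.models.qwen3 import apply_liger_kernel_to_qwen3

apply_liger_kernel_to_qwen3(
    fused_linear_cross_entropy=True,
    rms_norm=True,
    glu_activation=True,
)

# Module-level classes are replaced in place ...
assert modeling_qwen3.Qwen3RMSNorm is LigerRMSNorm
assert modeling_qwen3.Qwen3MLP is LigerSwiGLUMLP
# ... and the causal LM forward is now the Liger FLCE forward defined above.
assert modeling_qwen3.Qwen3ForCausalLM.forward.__name__ == "lce_forward"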
src/axolotl/integrations/liger/models/qwen3_moe.py

Lines changed: 191 additions & 0 deletions

@@ -0,0 +1,191 @@
"""
Liger FLCE for Qwen3 MoE. Based on transformers v4.51.3.
"""

import sys
from copy import deepcopy
from typing import List, Optional, Union

import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import load_balancing_loss_func


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them
            only for that token can save memory, which becomes pretty significant for long sequences or large
            vocabulary sizes. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the
            sequence length dimension. This is useful when using packed tensor format (single dimension for
            batch and sequence length).

    Returns:
    """
    # pylint: disable=duplicate-code
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        loss = LigerForCausalLMLoss(
            hidden_states=hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:  # if in inference mode, materialize logits
        slice_indices = (
            slice(-logits_to_keep, None)
            if isinstance(logits_to_keep, int)
            else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(
                loss.device
            )  # make sure to reside on the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def apply_liger_kernel_to_qwen3_moe(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> None:
    # pylint: disable=duplicate-code
    """
    Apply Liger kernels to replace the original implementations in HuggingFace Qwen3 MoE models.

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory
            efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
    """
    import transformers.models.qwen3_moe.modeling_qwen3_moe  # noqa: F401 # pylint: disable=unused-import
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    modeling_qwen3_moe = sys.modules["transformers.models.qwen3_moe.modeling_qwen3_moe"]

    if rms_norm:
        modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm

    if glu_activation:

        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
            "Accepts intermediate_size to pass to LigerSwiGLUMLP"
            # clone config to avoid modifying the original
            config = deepcopy(config)
            if intermediate_size:
                setattr(config, "intermediate_size", intermediate_size)
            return LigerSwiGLUMLP(config, **kwargs)

        modeling_qwen3_moe.Qwen3MoeMLP = _liger_swiglu_mlp_wrapper

    if layer_norm:
        modeling_qwen3_moe.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = lce_forward
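
Unlike the dense Qwen3 patch, the GLU swap here goes through _liger_swiglu_mlp_wrapper: Qwen3MoeMLP in transformers is constructed with an explicit intermediate_size argument (the per-expert moe_intermediate_size), while LigerSwiGLUMLP only reads intermediate_size from its config, so the wrapper clones the config and overrides that field. The sketch below is not part of the commit and uses default Qwen3MoeConfig values purely for illustration:

# Illustrative sketch of the expert MLP that _liger_swiglu_mlp_wrapper builds.
from copy import deepcopy

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig

config = Qwen3MoeConfig()  # default values, illustration only

# Clone the config and override intermediate_size with the per-expert size,
# exactly what the wrapper does before delegating to LigerSwiGLUMLP.
expert_cfg = deepcopy(config)
expert_cfg.intermediate_size = config.moe_intermediate_size
expert_mlp = LigerSwiGLUMLP(expert_cfg)

# gate/up projections are sized (moe_intermediate_size, hidden_size)
print(expert_mlp.gate_proj.weight.shape)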
