
Commit 2942b58

vvvdwbvvv and Tcc0403 authored
Add support for Qwen3Next model with Liger kernels (#912)
## Summary

Adds Qwen3-Next model support (#896): https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct?library=transformers. The model is available in transformers>=4.57.0.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Co-authored-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
1 parent c856fba commit 2942b58

File tree

10 files changed: +601 −3 lines changed

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -55,6 +55,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401

@@ -117,6 +118,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_qwen3_next",
         "apply_liger_kernel_to_smollm3",
     }

@@ -185,6 +187,7 @@ def __getattr__(name: str):
             "apply_liger_kernel_to_qwen2_vl",
             "apply_liger_kernel_to_qwen3",
             "apply_liger_kernel_to_qwen3_moe",
+            "apply_liger_kernel_to_qwen3_next",
             "apply_liger_kernel_to_smollm3",
         ]
     )
```
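The newly exported `apply_liger_kernel_to_qwen3_next` entry point is meant to be called before the model is constructed so the patched module classes are picked up everywhere. A minimal usage sketch (not part of this commit; it assumes transformers>=4.57.0 and uses the checkpoint named in the PR description):

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next

# Patch Qwen3Next classes (RMSNorm, SwiGLU MLP, fused linear cross entropy)
# before instantiating the model so the replacements apply to every layer.
apply_liger_kernel_to_qwen3_next()

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-Next-80B-A3B-Instruct",
    torch_dtype=torch.bfloat16,
)
```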
src/liger_kernel/transformers/model/qwen3_next.py (new file)

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@

````python
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Union

import torch

from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.modeling_outputs import MoeModelOutputWithPast

if TYPE_CHECKING:
    from transformers.models.qwen3_next.modeling_qwen3_next import load_balancing_loss_func

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")

    >>> prompt = "Give me a short introduction to large language model."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:  # if in inference mode, materialize logits
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
````
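For context on the `skip_logits` branch above: during training, the hidden states and `lm_head` weight are handed directly to `LigerForCausalLMLoss`, so the full `[batch, seq_len, vocab_size]` logits tensor is never materialized. A rough sketch of the unfused baseline it replaces (plain PyTorch for illustration only, not code from this commit):

```python
import torch
import torch.nn.functional as F


def unfused_causal_lm_loss(hidden_states, lm_head_weight, labels, ignore_index=-100):
    # Materializes the full [batch, seq_len, vocab_size] logits tensor,
    # which dominates activation memory for large vocabularies.
    logits = hidden_states @ lm_head_weight.T
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)).float(),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
```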

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 92 additions & 0 deletions
```diff
@@ -2180,6 +2180,97 @@ def apply_liger_kernel_to_falcon_h1(
             _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)


+def apply_liger_kernel_to_qwen3_next(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Qwen3Next models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen3_next import modeling_qwen3_next
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+    from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+    from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+    if rope:
+        # It might encounter NaN issues
+        # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+    if rms_norm:
+        modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            if isinstance(model, Qwen3NextForCausalLM):
+                model.forward = MethodType(qwen3_next_lce_forward, model)
+            else:
+                raise TypeError(
+                    f"fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+                )
+        else:
+            modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+    if swiglu:
+        # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+        modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+            base_model: Qwen3NextModel = getattr(model, model.base_model_prefix, model)
+        else:
+            raise TypeError(
+                f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM` or `Qwen3NextModel`. Got: {type(model)}"
+            )
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+            # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+            if swiglu:
+                if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+                    _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+                    _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+                    experts = getattr(decoder_layer.mlp, "experts", None)
+                    if experts is not None:
+                        for expert in experts:
+                            _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -2207,6 +2298,7 @@ def apply_liger_kernel_to_falcon_h1(
     "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen3_next": apply_liger_kernel_to_qwen3_next,
     "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
```

src/liger_kernel/transformers/rms_norm.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -77,3 +77,10 @@ def __init__(
         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
     ):
         super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForQwen3Next(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
```
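The `offset=1.0`, `casting_mode="gemma"`, `init_fn="zeros"` combination suggests Qwen3Next's RMSNorm follows the Gemma-style, zero-centered convention: the learnable weight starts at zero, the effective scale is `(1 + weight)`, and the normalization is computed in float32. A reference-only sketch of that convention (not the Liger kernel itself):

```python
import torch
import torch.nn as nn


class ZeroCenteredRMSNorm(nn.Module):
    """Reference-only sketch of a Gemma-style, zero-centered RMSNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Weight starts at zero; the effective scale applied after normalization is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * (1.0 + self.weight.float())).to(x.dtype)
```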

test/convergence/bf16/test_mini_models.py

Lines changed: 68 additions & 1 deletion
```diff
@@ -40,6 +40,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3
 from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe
+from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next
 from liger_kernel.transformers import apply_liger_kernel_to_smollm3
 from test.utils import DEFAULT_DATASET_PATH
 from test.utils import MiniModelConfig
@@ -68,6 +69,7 @@
 from test.utils import revert_liger_kernel_to_qwen2_vl
 from test.utils import revert_liger_kernel_to_qwen3
 from test.utils import revert_liger_kernel_to_qwen3_moe
+from test.utils import revert_liger_kernel_to_qwen3_next
 from test.utils import revert_liger_kernel_to_smollm3
 from test.utils import set_seed
 from test.utils import simple_collate_fn
@@ -212,6 +214,15 @@
 except ImportError:
     FALCONH1_AVAILABLE = False

+try:
+    # Qwen3Next is only available in transformers>=4.57.0
+    from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+
+    QWEN3NEXT_AVAILABLE = True
+except ImportError:
+    QWEN3NEXT_AVAILABLE = False
+
 from liger_kernel.utils import infer_device

 device = infer_device()
@@ -1106,6 +1117,43 @@
         ),
     )

+if QWEN3NEXT_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_qwen3_next"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_qwen3_next,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen3_next,
+        model_class=Qwen3NextForCausalLM,
+        mini_model_config=Qwen3NextConfig(  # copied from the mini Qwen3MoeConfig
+            vocab_size=32000,
+            hidden_size=896,
+            intermediate_size=4864,
+            num_hidden_layers=4,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            hidden_act="silu",
+            max_position_embeddings=32768,
+            initializer_range=0.02,
+            rms_norm_eps=1e-6,
+            use_cache=True,
+            tie_word_embeddings=False,
+            rope_theta=10000.0,
+            rope_scaling=None,
+            attention_bias=False,
+            use_sliding_window=False,
+            sliding_window=4096,
+            max_window_layers=28,
+            attention_dropout=0.0,
+            decoder_sparse_step=1,
+            moe_intermediate_size=768,
+            num_experts_per_tok=2,
+            num_experts=8,
+            norm_topk_prob=False,
+            output_router_logits=False,
+            router_aux_loss_coef=0.001,
+            # config.dtype must be set if fla is installed, since there's a bug in the original code (no torch.get_current_dtype())
+            dtype=torch.bfloat16,
+        ),
+    )
+

 def create_model(model_name="mini_llama4"):
     """
@@ -1141,7 +1189,7 @@ def run_mini_model(
         "rms_norm": True,
     }

-    if "glm4" in model_name:
+    if "glm4" in model_name or "qwen3_next" in model_name:
         kwargs["rope"] = False

     model_supports_layer_norm = "qwen2_vl" in model_name
@@ -1634,6 +1682,25 @@ def run_mini_model(
                 ),
             ],
         ),
+        pytest.param(
+            "mini_qwen3_next",
+            32,
+            1e-5,
+            torch.bfloat16,
+            1e-2,
+            1e-2,
+            1e-1,
+            1e-1,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not QWEN3NEXT_AVAILABLE,
+                    reason="Qwen3Next not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model(
```
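The PR checklist references `make test-convergence` for the full suite; to run only the new bf16 case, something along these lines should work (a sketch, assuming it is launched from the repo root on a bfloat16-capable GPU):

```python
import pytest

# Select only the mini_qwen3_next parameterization of test_mini_model.
pytest.main(["test/convergence/bf16/test_mini_models.py", "-k", "mini_qwen3_next"])
```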
