
Commit 98d9e19

Kingsleyandher, aureli, and shimizust authored
add hunyuanv1 dense and moe model (#940)
## Summary

This pull request introduces support for the HunyuanV1 dense and MoE models within the Liger-Kernel framework. [HunyuanV1 Model PR](huggingface/transformers#39606)

## Testing Done

Test results (screenshot): https://github.com/user-attachments/assets/eaab7b15-7737-4285-9f23-1d01cc09ee91

And a simple test (screenshot): https://github.com/user-attachments/assets/fa6155f3-9456-4054-a874-41e1c25e4b47

- Hardware Type: H20
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Co-authored-by: aureli <aureli@tecent.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
1 parent 24416a4 commit 98d9e19

File tree: 11 files changed (+840, -0 lines)


README.md

Lines changed: 2 additions & 0 deletions
@@ -264,6 +264,8 @@ loss.backward()
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


 ## Low-level APIs
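For context on how the two new entries above are used, the sketch below calls the dense patch function before loading a checkpoint; the checkpoint id `tencent/Hunyuan-7B-Instruct` is a placeholder and not taken from this PR.

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense

# Patch the HuggingFace Hunyuan v1 dense modeling module in place so that RoPE,
# RMSNorm, SwiGLU, and the fused linear cross entropy path all use Liger kernels.
# Calling this before from_pretrained means the patched classes are used when the
# model is constructed.
apply_liger_kernel_to_hunyuan_v1_dense()

model = AutoModelForCausalLM.from_pretrained(
    "tencent/Hunyuan-7B-Instruct",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
)
```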

src/liger_kernel/transformers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -42,6 +42,8 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
@@ -128,6 +130,8 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen3_vl_moe",
         "apply_liger_kernel_to_smollm3",
         "apply_liger_kernel_to_smolvlm",
+        "apply_liger_kernel_to_hunyuan_v1_dense",
+        "apply_liger_kernel_to_hunyuan_v1_moe",
     }

     if name in monkey_patch_symbols:
@@ -202,5 +206,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen3_vl_moe",
         "apply_liger_kernel_to_smollm3",
         "apply_liger_kernel_to_smolvlm",
+        "apply_liger_kernel_to_hunyuan_v1_dense",
+        "apply_liger_kernel_to_hunyuan_v1_moe",
     ]
 )
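A short, hedged sketch of what the registration above buys at import time; nothing here goes beyond the names the diff adds.

```python
# The patch functions are exported from the top-level transformers namespace and
# also listed in monkey_patch_symbols, so both access styles resolve to them.
from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense

import liger_kernel.transformers as liger_transformers

patch_fn = getattr(liger_transformers, "apply_liger_kernel_to_hunyuan_v1_moe")
patch_fn()  # defaults: rope, rms_norm, swiglu, fused_linear_cross_entropy enabled
```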
src/liger_kernel/transformers/model/hunyuan_v1.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
from typing import List
from typing import Optional
from typing import Union

import torch

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> LigerCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, HunYuanDenseV1ForCausalLM

    >>> model = HunYuanDenseV1ForCausalLM.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-hunyuan_v1_dense/HunYuanDenseV1-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
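To make the control flow of `lce_forward` concrete, here is a hedged sketch of how the patched forward behaves once bound to a HunYuan causal LM; `model` and `batch` are assumptions (an already-patched model and a dict of `input_ids`/`attention_mask`/`labels` tensors), and `return_dict=True` is assumed.

```python
import torch

model.train()
out = model(**batch)

# In training mode with labels present, skip_logits defaults to True: the fused
# LigerForCausalLMLoss consumes the hidden states and lm_head weight directly,
# so the loss is populated while logits stay None (no full-vocab logits are
# materialized).
assert out.loss is not None and out.logits is None

model.eval()
with torch.no_grad():
    out = model(**{k: v for k, v in batch.items() if k != "labels"})

# Outside the skip path, lm_head runs as usual and logits are returned;
# without labels or shift_labels, no loss is computed.
assert out.logits is not None and out.loss is None
```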

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 119 additions & 0 deletions
@@ -2558,6 +2558,123 @@ def apply_liger_kernel_to_qwen3_next(
                     _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)


+def apply_liger_kernel_to_hunyuan_v1_dense(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+    from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_hunyuan_v1_moe(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 MoE models.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+    from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+    from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+    from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+    if rope:
+        modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+        else:
+            modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+    if swiglu:
+        modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -2595,6 +2712,8 @@ def apply_liger_kernel_to_qwen3_next(
     "paligemma": apply_liger_kernel_to_paligemma,
     "falcon_h1": apply_liger_kernel_to_falcon_h1,
     "smolvlm": apply_liger_kernel_to_smolvlm,
+    "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+    "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
 }
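One detail worth noting in the two functions above is the `model=` path: when a model instance already exists, swapping module-level classes is not enough, so the forward is rebound with `MethodType` and the live RMSNorm modules and (for MoE) every expert MLP are patched in place. A minimal sketch, with `tencent/Hunyuan-A13B-Instruct` used purely as a placeholder checkpoint id:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe

model = AutoModelForCausalLM.from_pretrained("tencent/Hunyuan-A13B-Instruct")  # placeholder id

# The model already exists, so pass it explicitly: forward is rebound on the
# instance, base_model.norm is patched, and each decoder layer's expert MLPs and
# layernorms are patched in place (mirroring the loops above).
apply_liger_kernel_to_hunyuan_v1_moe(model=model)
```

Registering `"hunyuan_v1_dense"` and `"hunyuan_v1_moe"` in `MODEL_TYPE_TO_APPLY_LIGER_FN` also lets Liger's model-type dispatch pick the right patch function from `config.model_type` automatically.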

src/liger_kernel/transformers/swiglu.py

Lines changed: 17 additions & 0 deletions
@@ -77,3 +77,20 @@ def __init__(self, config, intermediate_size=None):

     def forward(self, x):
         return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
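As a sanity check on `LigerHunyuanV1SwiGLUMLP`, the sketch below instantiates it with a minimal stand-in config; the `SimpleNamespace` config is an assumption for illustration (real usage passes the HunYuan config object), and the fused kernel is Triton-based, so this needs a GPU to actually run.

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP

# Stand-in config carrying only the fields the module reads.
config = SimpleNamespace(hidden_size=64, intermediate_size=128, hidden_act="silu")

mlp = LigerHunyuanV1SwiGLUMLP(config, layer_idx=0).cuda()

x = torch.randn(2, 8, 64, device="cuda")
# Computes down_proj(SiLU(gate_proj(x)) * up_proj(x)) via the fused
# LigerSiLUMulFunction instead of separate activation and multiply ops.
y = mlp(x)
assert y.shape == x.shape
```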
