
Commit de41fd0

aiyiwang2025, quinnrong94, and jeejeelee authored and committed
[Model]Add Tencent HunYuanMoEV1 Model Support (vllm-project#20114)
Signed-off-by: aiyiwang <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: quinnrong <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
1 parent 48c33a4 commit de41fd0

File tree: 6 files changed, +949 −6 lines changed

docs/models/supported_models.md
Lines changed: 2 additions & 1 deletion

@@ -350,6 +350,7 @@ Specified using `--task generate`.
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
+| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ |
 | `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -387,7 +388,7 @@ Specified using `--task generate`.
 | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
+| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
 | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
 | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |
 
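Since the test registry below pins `trust_remote_code=True` for this architecture, a minimal smoke test of the newly listed model might look like the following sketch. It assumes the `tencent/Hunyuan-A13B-Instruct` weights are reachable from your environment and that you have hardware suited to an 80B-total / 13B-active MoE; neither is part of this commit.

# Minimal generation check for the new architecture; model availability and
# hardware requirements are assumptions, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="tencent/Hunyuan-A13B-Instruct", trust_remote_code=True)
params = SamplingParams(temperature=0.7, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)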

tests/models/registry.py
Lines changed: 3 additions & 1 deletion

@@ -188,6 +188,8 @@ def check_available_online(
     "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"),  # noqa: E501
     "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
                                              trust_remote_code=True),
+    "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct",
+                                               trust_remote_code=True),
     "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
                                            trust_remote_code=True),
     "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
@@ -490,4 +492,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
         raise ValueError(f"No example model defined for {model_id}")


-HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
+HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
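For context, a sketch of how this registry entry is consumed inside vLLM's test tree: the lookup uses the `find_hf_info` method shown above, while the `trust_remote_code` attribute on the result is assumed here to be a field of `_HfExamplesInfo`.

# Hypothetical lookup against the example-model registry defined above;
# only meaningful when run from within vLLM's own test tree.
from tests.models.registry import HF_EXAMPLE_MODELS

info = HF_EXAMPLE_MODELS.find_hf_info("tencent/Hunyuan-A13B-Instruct")
assert info.trust_remote_code  # Hunyuan ships custom modeling code on the Hub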

tests/models/test_initialization.py
Lines changed: 2 additions & 1 deletion

@@ -33,7 +33,8 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:

         # Ensure at least 2 experts per group
         # Since `grouped_topk` assumes top-2
-        num_experts = getattr(text_config, 'n_group', 1) * 2
+        n_group = getattr(text_config, 'n_group', None)
+        num_experts = n_group * 2 if n_group is not None else 2

         text_config.update({
             "num_layers": 1,

vllm/model_executor/layers/rotary_embedding.py
Lines changed: 44 additions & 3 deletions

@@ -533,6 +533,41 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         return cache


+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK alpha.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # For Hunyuan DynamicNTKAlphaRotaryEmbedding
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha**(self.rotary_dim /
+                                                (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 # Inverse dim formula to find dim based on number of rotations
 def _yarn_find_correction_dim(num_rotations: int,
                               dim: int,
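The new `_compute_cos_sin_cache` rescales the RoPE base by `alpha**(d / (d - 2))` for rotary dimension `d`, rather than deriving a new base from a target-length `factor` the way `DynamicNTKScalingRotaryEmbedding` does. A standalone numeric sketch of that adjustment, with illustrative values rather than ones taken from any released Hunyuan config:

# Standalone sketch of the NTK-alpha base adjustment used above.
# base, rotary_dim, and scaling_alpha are illustrative values only.
base = 10000.0
rotary_dim = 128
scaling_alpha = 1000.0

# Same formula as _compute_cos_sin_cache: base * alpha**(d / (d - 2)).
new_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))

# Inverse frequencies as in RoPE: base**(-2i/d) for each even dimension i.
inv_freq = [new_base ** (-2 * i / rotary_dim) for i in range(0, rotary_dim, 2)]

# A larger effective base lowers every frequency, so positions rotate more
# slowly and the cos/sin cache spans a longer context before phases wrap.
print(f"{new_base:.3e}", inv_freq[1])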
@@ -1929,9 +1964,15 @@ def get_rope(
                                  mixed_b)
     elif scaling_type == "dynamic":
         scaling_factor = rope_scaling["factor"]
-        rotary_emb = DynamicNTKScalingRotaryEmbedding(
-            head_size, rotary_dim, max_position, base, is_neox_style,
-            scaling_factor, dtype)
+        scaling_alpha = rope_scaling["alpha"]
+        if scaling_alpha:
+            rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_alpha, dtype)
+        else:
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor, dtype)
     elif scaling_type == "yarn":
         scaling_factor = rope_scaling["factor"]
         original_max_position = rope_scaling[
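With this branch in place, `get_rope` picks the embedding class from the contents of the `rope_scaling` dict: a truthy `"alpha"` selects the new NTK-alpha path, anything else keeps factor-based scaling. A hedged illustration of the two routes; the key names mirror the code above, but the concrete values are placeholders, not copied from any released config.json (note the committed branch reads both `"factor"` and `"alpha"`, so both keys are shown):

# Hypothetical rope_scaling dicts illustrating the new dispatch.
hunyuan_style = {"rope_type": "dynamic", "factor": 1.0, "alpha": 1000.0}
classic_dynamic = {"rope_type": "dynamic", "factor": 4.0, "alpha": None}

def pick_embedding(rope_scaling: dict) -> str:
    # Mirrors the committed branch: a truthy "alpha" wins, else factor scaling.
    return ("DynamicNTKAlphaRotaryEmbedding" if rope_scaling["alpha"]
            else "DynamicNTKScalingRotaryEmbedding")

print(pick_embedding(hunyuan_style))    # -> DynamicNTKAlphaRotaryEmbedding
print(pick_embedding(classic_dynamic))  # -> DynamicNTKScalingRotaryEmbedding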
