
Commit f2baa8d

refactor: add MTP template for future support. (#627)
1 parent 6493d8e · commit f2baa8d

17 files changed · +524 −678 lines

xllm/core/framework/model/model_input_params.h

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,7 @@ struct ModelInputParams {
     params.layer_synchronizer = layer_synchronizer;
 #endif
     params.expert_load_data = expert_load_data;
+    params.expert_array = expert_array;
 
     params.swap_blocks = std::move(swap_blocks);
 
@@ -401,6 +402,7 @@ struct ModelInputParams {
   DpEpPaddingData dp_ep_padding_data;
 
   torch::Tensor expert_load_data;
+  torch::Tensor expert_array;
 
   torch::Tensor kv_cache_tokens_nums;
   std::vector<int32_t> kv_cache_tokens_nums_host;
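
With this change, expert_array travels inside ModelInputParams alongside expert_load_data instead of being threaded through every decoder layer's forward() signature. A minimal sketch of the resulting calling pattern; build_expert_array() is a hypothetical stand-in for however the routing tensor is actually produced upstream:

  // Sketch only: populate the routing tensor once, next to the other
  // per-step inputs, rather than passing it as a separate argument.
  ModelInputParams params;
  params.expert_load_data = expert_load_data;
  params.expert_array = build_expert_array(batch);  // hypothetical helper

  // Downstream layers read input_params.expert_array themselves, so each
  // layer forward() signature shrinks by one parameter (see the diffs below).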

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 3 additions & 7 deletions
@@ -344,7 +344,6 @@ torch::Tensor NpuGlm4MoeDecoderImpl::forward(
     torch::Tensor& attn_mask,
     KVCache& kv_cache,
     const ModelInputParams& input_params,
-    torch::Tensor& expert_array,
     aclrtEvent* event,
     std::atomic<bool>* event_flag,
     int node_id) {
@@ -357,11 +356,10 @@ torch::Tensor NpuGlm4MoeDecoderImpl::forward(
                             attn_mask,
                             kv_cache,
                             input_params,
-                            expert_array,
                             true);
     st = execute_node(prefill_node_, node_id, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
-                           << "excute prefill layer fail, error code: " << st;
+                           << " excute prefill layer fail, error code: " << st;
   } else {
     build_node_variant_pack(decode_node_,
                             x,
@@ -370,11 +368,10 @@ torch::Tensor NpuGlm4MoeDecoderImpl::forward(
                             /*attn_mask*/ tensor_placeholder_,
                             kv_cache,
                             input_params,
-                            expert_array,
                             false);
     st = execute_node(decode_node_, node_id + 1000, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
-                           << "excute decode layer fail, error code: " << st;
+                           << " excute decode layer fail, error code: " << st;
   }
 
   return tensor_placeholder_;
@@ -388,7 +385,6 @@ void NpuGlm4MoeDecoderImpl::build_node_variant_pack(
     torch::Tensor& attn_mask,
     KVCache& kv_cache,
     const ModelInputParams& input_params,
-    torch::Tensor& expert_array,
     bool is_prefill) {
   internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x);
   auto& dp_ep_padding = input_params.dp_ep_padding_data;
@@ -421,7 +417,7 @@ void NpuGlm4MoeDecoderImpl::build_node_variant_pack(
       atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
 
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) =
-      atb_speed::Utils::AtTensor2Tensor(expert_array);
+      atb_speed::Utils::AtTensor2Tensor(input_params.expert_array);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 12) =
       atb_speed::Utils::AtTensor2Tensor(expert_group_);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 13) =
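
The layer now reads the routing tensor from input_params inside build_node_variant_pack(), so call sites drop the extra argument. A hedged sketch of the new call shape (the leading parameters before attn_mask are not shown in this hunk, so their names here are illustrative):

  // Sketch: one fewer argument; expert_array arrives via input_params.
  layer->forward(x,
                 attn_mask,
                 kv_cache,
                 input_params,  // now carries expert_array
                 event,
                 event_flag,
                 node_id);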

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h

Lines changed: 0 additions & 2 deletions
@@ -49,7 +49,6 @@ class NpuGlm4MoeDecoderImpl : public BaseLayer {
       torch::Tensor& attn_mask,
       KVCache& kv_cache,
       const ModelInputParams& input_params,
-      torch::Tensor& expert_array,
       aclrtEvent* event = nullptr,
       std::atomic<bool>* event_flag = nullptr,
       int node_id = 0);
@@ -100,7 +99,6 @@ class NpuGlm4MoeDecoderImpl : public BaseLayer {
       torch::Tensor& attn_mask,
       KVCache& kv_cache,
       const ModelInputParams& input_params,
-      torch::Tensor& expert_array,
       bool is_prefill);
 
   std::string model_name_;

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 1 addition & 5 deletions
@@ -283,7 +283,6 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
     torch::Tensor& attn_mask,
     KVCache& kv_cache,
     const ModelInputParams& input_params,
-    torch::Tensor& expert_array,
     aclrtEvent* event,
     std::atomic<bool>* event_flag,
     int node_id) {
@@ -296,7 +295,6 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
                             attn_mask,
                             kv_cache,
                             input_params,
-                            expert_array,
                             true);
     st = execute_node(prefill_node_, node_id, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
@@ -309,7 +307,6 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
                             /*attn_mask*/ tensor_placeholder_,
                             kv_cache,
                             input_params,
-                            expert_array,
                             false);
     st = execute_node(decode_node_, node_id + 1000, event, event_flag);
     LOG_IF(FATAL, st != 0) << model_name_
@@ -327,15 +324,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
     torch::Tensor& attn_mask,
     KVCache& kv_cache,
     const ModelInputParams& input_params,
-    torch::Tensor& expert_array,
     bool is_prefill) {
   internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x);
   int32_t input_idx = 0;
   auto& dp_ep_padding = input_params.dp_ep_padding_data;
 
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensor_;
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
-      atb_speed::Utils::AtTensor2Tensor(expert_array);
+      atb_speed::Utils::AtTensor2Tensor(input_params.expert_array);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
       atb_speed::Utils::AtTensor2Tensor(expert_group_);
   node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h

Lines changed: 0 additions & 2 deletions
@@ -54,7 +54,6 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer {
       torch::Tensor& attn_mask,
       KVCache& kv_cache,
       const ModelInputParams& input_params,
-      torch::Tensor& expert_array,
       aclrtEvent* event = nullptr,
       std::atomic<bool>* event_flag = nullptr,
       int node_id = 0);
@@ -104,7 +103,6 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer {
       torch::Tensor& attn_mask,
       KVCache& kv_cache,
       const ModelInputParams& input_params,
-      torch::Tensor& expert_array,
       bool is_prefill);
 
   torch::Tensor block_tables_placeholder_;

xllm/models/llm/npu/deepseek_mtp.h

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@ (new file; all lines added)
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "core/layers/common/rotary_embedding_util.h"
#include "deepseek_v2.h"
#include "mtp_model_base.h"

// DeepSeek v2 compatible with huggingface weights
// ref to:
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py

namespace xllm {

class DeepseekMtpModelImpl : public MtpModelImplBase<DeepseekV2DecoderLayer> {
 public:
  DeepseekMtpModelImpl(const ModelContext& context)
      : MtpModelImplBase<DeepseekV2DecoderLayer>("deepseek_v3_mtp", context) {
    auto model_args = context.get_model_args();
    auto options = context.get_tensor_options();

    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
    attn_mask_ = layer::AttentionMask(options.device(),
                                      options.dtype().toScalarType(),
                                      /*mask_value=*/mask_value);

    cos_sin_ = layer::rotary::get_deepseek_rotary_embedding(
        model_args.qk_rope_head_dim(),
        model_args.qk_rope_head_dim(),
        model_args.max_position_embeddings(),
        model_args.rope_scaling_original_max_position_embeddings(),
        model_args.rope_theta(),
        /*interleaved*/ false,
        model_args.rope_scaling_factor(),
        model_args.rope_extrapolation_factor(),
        model_args.rope_scaling_attn_factor(),
        model_args.rope_scaling_beta_fast(),
        model_args.rope_scaling_beta_slow(),
        model_args.rope_scaling_mscale(),
        model_args.rope_scaling_mscale_all_dim(),
        options);
  }
};
TORCH_MODULE(DeepseekMtpModel);

class DeepseekMtpForCausalLMImpl
    : public MtpForCausalLMImplBase<DeepseekMtpModel> {
 public:
  DeepseekMtpForCausalLMImpl(const ModelContext& context)
      : MtpForCausalLMImplBase<DeepseekMtpModel>(context) {}
};
TORCH_MODULE(DeepseekMtpForCausalLM);

// register the causal model
REGISTER_CAUSAL_MODEL(deepseek_v3_mtp, DeepseekMtpForCausalLM);

// example config:
// https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
REGISTER_MODEL_ARGS(deepseek_v3_mtp, [&] {
  LOAD_ARG_OR(model_type, "model_type", "deepseek_v3_mtp");
  LOAD_ARG_OR(dtype, "torch_dtype", "");
  LOAD_ARG_OR(vocab_size, "vocab_size", 129280);
  LOAD_ARG_OR(hidden_size, "hidden_size", 7168);
  LOAD_ARG_OR(n_layers, "num_hidden_layers", 61);
  LOAD_ARG_OR(n_heads, "num_attention_heads", 128);
  LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 128);
  LOAD_ARG_OR(intermediate_size, "intermediate_size", 18432);
  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 163840);
  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
  LOAD_ARG_OR(eos_token_id, "eos_token_id", 1);
  LOAD_ARG_OR(bos_token_id, "bos_token_id", 0);
  LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
  LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
  LOAD_ARG_OR(max_window_layers, "max_window_layers", 61);

  LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 0);
  LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1);
  LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc");
  LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256);
  LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 1);
  LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8);
  LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 2048);
  LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 2.5f);
  LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true);
  LOAD_ARG_OR(n_group, "n_group", 8);
  LOAD_ARG_OR(topk_group, "topk_group", 4);
  LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128);
  LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64);
  LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
  LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536);
  LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);

  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
  });
  LOAD_ARG_OR_FUNC(
      rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); });

  SET_ARG(rope_scaling_rope_type, "deepseek_yarn");
  LOAD_ARG(rope_scaling_beta_fast, "rope_scaling.beta_fast");
  LOAD_ARG(rope_scaling_beta_slow, "rope_scaling.beta_slow");
  LOAD_ARG(rope_scaling_factor, "rope_scaling.factor");
  LOAD_ARG_OR(
      rope_extrapolation_factor, "rope_scaling.extrapolation_factor", 1.0f);
  LOAD_ARG(rope_scaling_mscale, "rope_scaling.mscale");
  LOAD_ARG(rope_scaling_mscale_all_dim, "rope_scaling.mscale_all_dim");
  LOAD_ARG(rope_scaling_original_max_position_embeddings,
           "rope_scaling.original_max_position_embeddings");
  LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling.attn_factor", 1.0f);

  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({1}));
});
}  // namespace xllm
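
The point of the refactor is that MtpModelImplBase is a reusable template: a future MTP head supplies only its decoder-layer type, model name, and model-specific mask/rotary setup. A hypothetical sketch of what a follow-up model might look like (Qwen3MtpModelImpl and its layer type are illustrative names, not part of this commit):

  class Qwen3MtpModelImpl : public MtpModelImplBase<Qwen3MoeDecoderLayer> {
   public:
    Qwen3MtpModelImpl(const ModelContext& context)
        : MtpModelImplBase<Qwen3MoeDecoderLayer>("qwen3_moe_mtp", context) {
      // Initialize model-specific attn_mask_ / cos_sin_ here, mirroring
      // DeepseekMtpModelImpl above.
    }
  };
  TORCH_MODULE(Qwen3MtpModel);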

xllm/models/llm/npu/deepseek_v2.h

Lines changed: 0 additions & 6 deletions
@@ -142,15 +142,10 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
     norm_ = register_module("norm", layer::NpuRMSNorm(context));
 
     dp_size_ = parallel_args.dp_size();
-    std::vector<int64_t> indices;
     dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
     dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
     rank_ = parallel_args.rank();
-    mapping_data_ = parallel_args.mapping_data();
     num_experts_per_tok_ = model_args.num_experts_per_tok();
-    for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
-      indices.push_back(i);
-    }
   }
 
   torch::Tensor forward(torch::Tensor tokens,
@@ -258,7 +253,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
   int32_t rank_;
   int32_t dp_size_;
   int32_t dp_local_tp_size_;
-  nlohmann::json mapping_data_;
   int32_t num_experts_per_tok_;
   int32_t num_speculative_tokens_ = 0;
   at::Device device_;
