Commit 395231d

feat: support Qwen3-VL-MOE model on npu device. (#313)
1 parent 53b6e6f commit 395231d

4 files changed, +207 -19 lines changed

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 0 additions & 17 deletions
@@ -96,13 +96,10 @@ static const std::unordered_map<std::string, int> WEIGHT_MAPPING = {
     {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT},
 
     {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0},
-    {"self_attn.q_proj.bias", IN_QKV_BIAS_0},
 
     {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1},
-    {"self_attn.k_proj.bias", IN_QKV_BIAS_1},
 
     {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2},
-    {"self_attn.v_proj.bias", IN_QKV_WEIGHT_2},
 
     {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT},
 
@@ -184,11 +181,8 @@ static const std::unordered_map<std::string, std::vector<int>>
 
 static const std::map<int, int> WEIGHT_SHARD = {
     {IN_QKV_WEIGHT_0, 0},
-    {IN_QKV_BIAS_0, 0},
     {IN_QKV_WEIGHT_1, 0},
-    {IN_QKV_BIAS_1, 0},
     {IN_QKV_WEIGHT_2, 0},
-    {IN_QKV_BIAS_2, 0},
     {IN_ATTENTION_OUT_WEIGHT, 1},
     {IN_MLP_GATEUP_WEIGHT_EXPERT, 0},
     {IN_MLP_DOWN_WEIGHT_EXPERT, 1},
@@ -658,17 +652,6 @@ void NpuQwen3MoeDecoderLayerImpl::merge_loaded_weights() {
   at_weight_tensors_[IN_QKV_WEIGHT_2] =
       torch::zeros({1}, torch::kFloat16).to(device_);
 
-  at_weight_tensors_[IN_QKV_BIAS_0] =
-      torch::cat({at_weight_tensors_[IN_QKV_BIAS_0],
-                  at_weight_tensors_[IN_QKV_BIAS_1],
-                  at_weight_tensors_[IN_QKV_BIAS_2]},
-                 0)
-          .contiguous();
-  at_weight_tensors_[IN_QKV_BIAS_1] =
-      torch::zeros({1}, torch::kFloat16).to(device_);
-  at_weight_tensors_[IN_QKV_BIAS_2] =
-      torch::zeros({1}, torch::kFloat16).to(device_);
-
   if (quantize_type_.compare("w8a8_dynamic") == 0) {
     at_weight_tensors_[IN_QKV_BIAS_0] =
         torch::zeros({1}, torch::kFloat16).to(device_);
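
Note: the deleted lines drop the separate Q/K/V bias slots from the weight-name mapping, the shard table, and the bias-concatenation step in merge_loaded_weights(), leaving only the projection weights to be fused. The surviving weight handling presumably follows the same concatenate-then-clear pattern the deleted bias code used; a minimal libtorch sketch of that pattern (hypothetical names, not the xLLM implementation):

// Minimal sketch, not xLLM code: fuse per-projection tensors along dim 0 and
// replace the now-unused slots with tiny placeholders, mirroring the
// torch::zeros({1}, ...) assignments visible in the diff above.
#include <torch/torch.h>

torch::Tensor merge_qkv(torch::Tensor& q,
                        torch::Tensor& k,
                        torch::Tensor& v,
                        const torch::Device& device) {
  // [q_out, H], [k_out, H], [v_out, H]  ->  [q_out + k_out + v_out, H],
  // so one matmul can serve all three projections.
  auto fused = torch::cat({q, k, v}, /*dim=*/0).contiguous();
  k = torch::zeros({1}, torch::kFloat16).to(device);  // freed slot
  v = torch::zeros({1}, torch::kFloat16).to(device);  // freed slot
  return fused;
}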

xllm/models/llm/qwen3_moe.h

Lines changed: 0 additions & 2 deletions
@@ -272,8 +272,6 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
       if (input_params.layer_synchronizer != nullptr) {
         event = input_params.layer_synchronizer->get_event(i);
         event_flag = input_params.layer_synchronizer->get_event_flag(i);
-      } else {
-        LOG(INFO) << "layer_synchronizer is nullptr";
       }
       auto& layer = layers_[i];
       layer(h,

xllm/models/models.h

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ limitations under the License.
 #include "vlm/minicpmv.h"    // IWYU pragma: keep
 #include "vlm/qwen2_5_vl.h"  // IWYU pragma: keep
 #include "vlm/qwen3_vl.h"    // IWYU pragma: keep
+#include "vlm/qwen3_vl_moe.h"  // IWYU pragma: keep
 #endif
 
 #include "llm/llm_model_base.h"  // IWYU pragma: keep

xllm/models/vlm/qwen3_vl_moe.h

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <atb/atb_infer.h>
+#include <c10/core/ScalarType.h>
+#include <glog/logging.h>
+#include <torch/torch.h>
+
+#include <boost/algorithm/string.hpp>
+#include <unordered_map>
+
+#include "core/framework/kv_cache/kv_cache.h"
+#include "core/framework/model/model_input_params.h"
+#include "core/framework/model_context.h"
+#include "core/layers/lm_head.h"
+#include "core/layers/qwen3_vision_encode_layer.h"
+#include "core/layers/rms_norm.h"
+#include "models/llm/qwen3_moe.h"
+#include "models/model_registry.h"
+#include "processors/input_processor.h"
+#include "processors/qwen2_vl_image_processor.h"
+#include "qwen2_5_vl.h"
+#include "qwen3_vl.h"
+#include "xllm_kernels/core/include/atb_speed/log.h"
+
+namespace xllm {
+
+using torch::indexing::None;
+using ISlice = torch::indexing::Slice;
+
+class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module {
+ public:
+  Qwen3_VLMoeForConditionalGenerationImpl(const ModelContext& context)
+      : model_args_(context.get_model_args()),
+        options_(context.get_tensor_options()) {
+    visual_ = register_module("visual", Qwen3_VisionTransformer(context));
+
+    language_model_ =
+        register_module("language_model", Qwen3MoeForCausalLM(context));
+  }
+
+  torch::Tensor get_input_embeddings(
+      torch::Tensor input_ids,
+      const std::optional<Qwen3_VLImageInputs>& image_input,
+      const std::optional<Qwen3_VLVideoInputs>& video_input,
+      const ModelInputParams& input_params) {
+    auto inputs_embeds = language_model_->get_input_embeddings(input_ids);
+    if (image_input) {
+      // visual
+      auto [image_embeds, deep_stacks] =
+          visual_(image_input->pixel_values.to(options_),
+                  image_input->image_grid_thw,
+                  input_params);
+      input_params.deep_stacks = deep_stacks;
+      // merge
+      auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id());
+      input_params.visual_pos_masks = is_multimodal;
+      inputs_embeds.index_put_({is_multimodal}, image_embeds);
+    }
+    return inputs_embeds;
+  }
+
+  torch::Tensor forward(const std::vector<torch::Tensor>& tokens,
+                        const std::vector<torch::Tensor>& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const std::vector<ModelInputParams>& input_params) {
+    torch::NoGradGuard no_grad;
+    const auto& mm_data = input_params[0].mm_data;
+    torch::Tensor pixel_values;
+    if (const auto& res = mm_data.get<torch::Tensor>("pixel_values"))
+      pixel_values = res.value();
+
+    torch::Tensor image_grid_thw;
+    if (const auto& res = mm_data.get<torch::Tensor>("image_grid_thw"))
+      image_grid_thw = res.value();
+    std::optional<Qwen3_VLImageInputs> image_inputs;
+    std::optional<Qwen3_VLVideoInputs> video_inputs;
+
+    if (pixel_values.defined() && image_grid_thw.defined())
+      image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw};
+
+    auto inputs_embeds = get_input_embeddings(
+        tokens[0], image_inputs, video_inputs, input_params[0]);
+    input_params[0].input_embedding = inputs_embeds;
+    auto emb = language_model_(tokens, positions, kv_caches, input_params);
+
+    return emb;
+  }
+
+  torch::Tensor logits(const torch::Tensor& hidden_states,
+                       const torch::Tensor& seleted_idxes) {
+    return language_model_->logits(hidden_states, seleted_idxes);
+  }
+
+  void load_model(std::unique_ptr<ModelLoader> loader) {
+    for (const auto& state_dict : loader->get_state_dicts()) {
+      visual_->load_state_dict(
+          state_dict->get_dict_with_prefix("model.visual."));
+    }
+    // verify
+    visual_->verify_loaded_weights("model.visual.");
+    visual_->merge_loaded_weights();
+    if (!model_args_.image_embedding_mode()) {
+      language_model_->load_model(std::move(loader), "model.language_model.");
+    }
+  }
+
+  layer::LmHead get_lm_head() { return language_model_->get_lm_head(); }
+  void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); }
+
+  std::vector<layer::WordEmbedding> get_word_embedding() {
+    return language_model_->get_word_embedding();
+  }
+
+  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+    language_model_->set_word_embedding(word_embedding);
+  }
+
+ private:
+  ModelArgs model_args_;
+  torch::TensorOptions options_;
+  Qwen3_VisionTransformer visual_{nullptr};
+  Qwen3MoeForCausalLM language_model_{nullptr};
+};
+TORCH_MODULE(Qwen3_VLMoeForConditionalGeneration);
+
+REGISTER_INPUT_PROCESSOR(qwen3_vl_moe, Qwen2_5_VLInputProcessor);
+REGISTER_CAUSAL_VLM_MODEL(qwen3_vl_moe, Qwen3_VLMoeForConditionalGeneration);
+REGISTER_IMAGE_PROCESSOR(qwen3_vl_moe, Qwen2VLImageProcessor);
+// register the model args
+REGISTER_MODEL_ARGS(qwen3_vl_moe, [&] {
+  // text config
+  LOAD_ARG_OR(model_type, "model_type", "qwen3_vl_moe");
+  LOAD_ARG_OR(attention_bias, "text_config.attention_bias", false);
+  LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f);
+  LOAD_ARG_OR(bos_token_id, "text_config.bos_token_id", 151643);
+  LOAD_ARG_OR(decoder_sparse_step, "text_config.decoder_sparse_step", 1);
+  LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16");
+  LOAD_ARG_OR(eos_token_id, "text_config.eos_token_id", 151645);
+  LOAD_ARG_OR_FUNC(head_dim, "text_config.head_dim", [&] {
+    return args->hidden_size() / args->n_heads();
+  });
+  LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu");
+  LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 2048);
+  LOAD_ARG_OR(initializer_range, "text_config.initializer_range", 0.02);
+  LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 5632);
+  LOAD_ARG_OR(
+      max_position_embeddings, "text_config.max_position_embeddings", 128000);
+  // LOAD_ARG(mlp_only_layers, "text_config.mlp_only_layers");
+  LOAD_ARG_OR(moe_intermediate_size, "text_config.moe_intermediate_size", 1408);
+  LOAD_ARG_OR(norm_topk_prob, "text_config.norm_topk_prob", true);
+  LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 16);
+  LOAD_ARG_OR(num_experts, "text_config.num_experts", 128);
+  LOAD_ARG_OR(num_experts_per_tok, "text_config.num_experts_per_tok", 8);
+  LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 24);
+  LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 16);
+  LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-06);
+  LOAD_ARG_OR(rope_scaling_rope_type, "text_config.rope_scaling.type", "mrope");
+  LOAD_ARG(rope_scaling_mrope_section,
+           "text_config.rope_scaling.mrope_section");
+  // LOAD_ARG_OR(rope_scaling_mrope_interleaved,"text_config.rope_scaling.mrope_interleaved",true);
+  LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 5000000.0f);
+  LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151936);
+
+  // vision config
+  LOAD_ARG(mm_deepstack_visual_indexes,
+           "vision_config.deepstack_visual_indexes");
+  LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 27);
+  LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "gelu_pytorch_tanh");
+  LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1152);
+  LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3);
+  LOAD_ARG_OR(mm_initializer_range, "vision_config.initializer_range", 0.02);
+  LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4304);
+  LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16);
+  LOAD_ARG_OR(mm_num_position_embeddings,
+              "vision_config.num_position_embeddings",
+              2304);
+  LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 3584);
+  LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 16);
+  LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2);
+  LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2);
+  LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] {
+    return args->mm_hidden_size() / args->mm_num_attention_heads();
+  });
+
+  LOAD_ARG_OR(image_token_id, "image_token_id", 151655);
+  LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);
+  LOAD_ARG_OR(video_token_id, "video_token_id", 151656);
+  LOAD_ARG_OR(vision_end_token_id, "vision_end_token_id", 151653);
+  LOAD_ARG_OR(vision_start_token_id, "vision_start_token_id", 151652);
+});
+}  // namespace xllm
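
Note: get_input_embeddings() above splices the vision-tower output into the token embeddings at every position holding the image placeholder token (151655 by default, per the registered args). A self-contained libtorch sketch of that masked merge, using made-up shapes and none of the xLLM types:

// Self-contained sketch of the masked merge above (made-up shapes and values).
#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t image_token_id = 151655;  // default registered for qwen3_vl_moe
  auto input_ids = torch::tensor({1, 151655, 151655, 2}, torch::kLong);
  auto inputs_embeds = torch::zeros({4, 8});  // [seq_len, hidden] from the embed table
  auto image_embeds = torch::ones({2, 8});    // one row per image placeholder

  // Boolean mask of the positions that hold the placeholder token.
  auto is_multimodal = torch::isin(input_ids, image_token_id);
  // In-place scatter: masked rows of inputs_embeds are replaced, in order,
  // by the rows of image_embeds; this is the index_put_ call shown in the diff.
  inputs_embeds.index_put_({is_multimodal}, image_embeds);

  std::cout << inputs_embeds << std::endl;  // rows 1 and 2 are now all ones
  return 0;
}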

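
Note: the two LOAD_ARG_OR_FUNC fallbacks derive head sizes only when the checkpoint config omits them; with the defaults registered above, head_dim = 2048 / 16 = 128 for the text tower and mm_head_dim = 1152 / 16 = 72 for the vision tower. A trivial worked example (illustrative only; a real config.json normally supplies these values):

#include <cstdio>

int main() {
  // Fallback defaults registered above; a real config.json overrides them.
  const int hidden_size = 2048, n_heads = 16;           // text_config
  const int mm_hidden_size = 1152, mm_num_heads = 16;   // vision_config
  std::printf("head_dim = %d\n", hidden_size / n_heads);             // 128
  std::printf("mm_head_dim = %d\n", mm_hidden_size / mm_num_heads);  // 72
  return 0;
}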