Skip to content

Commit 3028d49

Browse files
authored
feat: support Qwen3-VL model on npu device. (#295)
1 parent 6162f01 commit 3028d49

21 files changed

+1580
-33
lines changed

CMakeLists.txt

100644 → 100755
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ if(USE_NPU)
2828
if(DEVICE_TYPE STREQUAL "USE_A3")
2929
message("downloading a3 arm xllm kernels")
3030
file(DOWNLOAD
31-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a3.arm.rpm"
31+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a3.arm.rpm"
3232
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
3333
)
3434
else()
3535
if(DEVICE_ARCH STREQUAL "ARM")
3636
message("downloading a2 arm xllm_kernels")
3737
file(DOWNLOAD
38-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a2.arm.rpm"
38+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a2.arm.rpm"
3939
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4040
)
4141
else()
4242
message("downloading a2 x86 xllm_kernels")
4343
file(DOWNLOAD
44-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.1-Linux.a2.x86.rpm"
44+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.6.0/xllm_kernels-1.3.2-Linux.a2.x86.rpm"
4545
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4646
)
4747
endif()

third_party/xllm_ops

Submodule xllm_ops updated from 2cda9bf to 797a0cb

xllm/core/framework/hf_model_loader.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ bool HFModelLoader::load_image_preprocessor_args(
360360
image_prerocess_data["norm_std"].get<std::vector<double>>();
361361
}
362362

363+
args_.mm_image_shortest_edge() =
364+
image_preprocess_reader.value_or<int>("size.shortest_edge", 0);
365+
366+
args_.mm_image_longest_edge() =
367+
image_preprocess_reader.value_or<int>("size.longest_edge", 0);
368+
363369
args_.mm_image_min_pixels() =
364370
image_preprocess_reader.value_or<int>("min_pixels", 0);
365371

xllm/core/framework/model/model_args.h

100644 → 100755
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,15 @@ struct ModelArgs {
242242

243243
PROPERTY(int, mm_window_size) = 0;
244244
PROPERTY(std::vector<int64_t>, mm_fullatt_block_indexes);
245+
PROPERTY(std::vector<int64_t>, mm_deepstack_visual_indexes);
245246
PROPERTY(int, mm_tokens_per_second) = 0;
246247
PROPERTY(int, mm_temporal_patch_size) = 0;
247248

248249
// VLM model projector's mm_projector_type
249250
PROPERTY(std::string, mm_projector_type);
250251

252+
//
253+
PROPERTY(int64_t, mm_num_position_embeddings);
251254
// VLM model projector's mm_projector_hidden_act
252255
PROPERTY(std::string, mm_projector_hidden_act);
253256

@@ -284,6 +287,9 @@ struct ModelArgs {
284287
PROPERTY(int, mm_image_min_pixels) = 0;
285288
PROPERTY(int, mm_image_max_pixels) = 0;
286289

290+
PROPERTY(int64_t, mm_image_shortest_edge) = 0;
291+
PROPERTY(int64_t, mm_image_longest_edge) = 0;
292+
287293
PROPERTY(int, mm_image_patch_size) = 0;
288294
PROPERTY(int, mm_image_temporal_patch_size) = 0;
289295
PROPERTY(int, mm_image_merge_size) = 0;
@@ -447,6 +453,11 @@ inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
447453
os << index << ",";
448454
}
449455
os << "]";
456+
os << ", mm_deepstack_visual_indexes: [";
457+
for (auto& index : args.mm_deepstack_visual_indexes()) {
458+
os << index << ",";
459+
}
460+
os << "]";
450461
os << ", mm_tokens_per_second: " << args.mm_tokens_per_second();
451462
os << ", mm_temporal_patch_size: " << args.mm_temporal_patch_size();
452463
os << ", mm_projector_type: " << args.mm_projector_type();
@@ -474,6 +485,8 @@ inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
474485
os << std << ", ";
475486
}
476487
os << "]";
488+
os << ", mm_image_shortest_edge: " << args.mm_image_shortest_edge();
489+
os << ", mm_image_longest_edge: " << args.mm_image_longest_edge();
477490
os << ", mm_image_min_pixels: " << args.mm_image_min_pixels();
478491
os << ", mm_image_max_pixels: " << args.mm_image_max_pixels();
479492
os << ", mm_image_patch_size: " << args.mm_image_patch_size();

xllm/core/framework/model/model_input_params.h

100644 → 100755
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ struct ModelInputParams {
6767

6868
params.input_embedding = safe_to(input_embedding, device);
6969

70+
params.deep_stacks = deep_stacks;
71+
params.visual_pos_masks = visual_pos_masks;
72+
7073
params.mm_data = MMData::to(mm_data, device);
7174
params.dp_global_token_nums = dp_global_token_nums;
7275
params.prefill_seq_len = prefill_seq_len;
@@ -149,6 +152,11 @@ struct ModelInputParams {
149152
// multimodal
150153
MMData mm_data;
151154

155+
// deep_stack for Qwen3-VL
156+
mutable std::vector<torch::Tensor> deep_stacks;
157+
// visual pos mask for Qwen3-VL
158+
mutable torch::Tensor visual_pos_masks;
159+
152160
// num tokens of all workers,mainly used for dp case
153161
std::vector<int32_t> dp_global_token_nums;
154162
// whether the kv-cache is empty for all sequences,mainly used for dp case

xllm/core/framework/quant_args.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct QuantArgs {
2727
PROPERTY(std::string, quant_method);
2828

2929
PROPERTY(std::string, quantize_type);
30-
PROPERTY(std::string, torch_dtype);
30+
PROPERTY(std::string, torch_dtype) = "bfloat16";
3131
// quantization bits
3232
PROPERTY(int64_t, bits) = 0;
3333

xllm/core/layers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ cc_library(
5353
multi_head_attention.h
5454
qwen2_decoder_layer.h
5555
qwen2dot5_vision_decode_layer.h
56+
qwen3_vision_encode_layer.h
5657
qwen3_decoder_layer.h
5758
qwen3_moe_decoder_layer.h
5859
rms_norm.h

xllm/core/layers/base_layer.cpp

100644 → 100755
File mode changed.

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ cc_library(
1010
npu_pos_embedding_impl.h
1111
npu_lm_head_impl.h
1212
npu_qwen2dot5_vision_encoder_layer_impl.h
13+
npu_qwen3_vision_encoder_layer_impl.h
1314
npu_qwen3_moe_decoder_layer_impl.h
1415
# atb_parallel_linear.h
1516
npu_block_copy_impl.h
@@ -29,6 +30,7 @@ cc_library(
2930
npu_pos_embedding_impl.cpp
3031
npu_lm_head_impl.cpp
3132
npu_qwen2dot5_vision_encoder_layer_impl.cpp
33+
npu_qwen3_vision_encoder_layer_impl.cpp
3234
npu_qwen3_moe_decoder_layer_impl.cpp
3335
# atb_parallel_linear.cpp
3436
npu_block_copy_impl.cpp

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

100644 → 100755
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
376376
const ModelArgs& args,
377377
const ParallelArgs& parallel_args) {
378378
param.hasSharedExpert = (args.n_shared_experts() > 0);
379-
param.hasSharedExpertGate = true;
379+
param.hasSharedExpertGate = false;
380380
param.processLogits = "normalization";
381381
param.numOfSelectedExperts = {args.num_experts_per_tok()};
382382

@@ -492,7 +492,6 @@ void NpuQwen3MoeDecoderLayerImpl::process_expert_weights(
492492
const int local_index = expert_index % num_experts_per_partition_;
493493
const bool is_sharded = shard_map.count(index);
494494

495-
std::lock_guard<std::mutex> lock(experts_mutex_);
496495
torch::Tensor tmp_tensor = is_sharded
497496
? get_sharded_tensor(state_dict,
498497
name,
@@ -517,8 +516,6 @@ void NpuQwen3MoeDecoderLayerImpl::process_mlp_common_weights(
517516
const int index = get_mapped_index(name, weight_mapping);
518517
const bool is_sharded = shard_map.count(index);
519518

520-
std::lock_guard<std::mutex> lock(shared_experts_mutex_);
521-
522519
torch::Tensor tmp_tensor = is_sharded
523520
? get_sharded_tensor(state_dict,
524521
name,
@@ -650,7 +647,6 @@ void NpuQwen3MoeDecoderLayerImpl::verify_loaded_weights(
650647

651648
void NpuQwen3MoeDecoderLayerImpl::merge_loaded_weights() {
652649
merge_experts_weights();
653-
654650
at_weight_tensors_[IN_QKV_WEIGHT_0] =
655651
torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0],
656652
at_weight_tensors_[IN_QKV_WEIGHT_1],

0 commit comments

Comments
 (0)