Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ DEFINE_bool(enable_mla,
false,
"Whether to enable multi-head latent attention.");

DEFINE_bool(enable_customize_mla_kernel, false, "enable customize mla kernel");

// --- graph mode execution config ---

DEFINE_bool(enable_acl_graph,
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ DECLARE_int32(expert_parallel_degree);

DECLARE_int32(max_reconnect_count);

DECLARE_bool(enable_customize_mla_kernel);

DECLARE_bool(enable_atb_comm_multiprocess);

DECLARE_string(tool_call_parser);
Expand Down
40 changes: 29 additions & 11 deletions xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ NpuDeepseekV2DecoderLayerImpl::NpuDeepseekV2DecoderLayerImpl(

param_from_args(prefill_param_, model_args, parallel_args, true);
param_from_args(decode_param_, model_args, parallel_args, false);
param_from_args(decode_mla_param_, model_args, parallel_args, false);
decode_mla_param_.enableCustomizeMla = FLAGS_enable_customize_mla_kernel;

initialize_tensors(options);
}
Expand Down Expand Up @@ -1437,6 +1439,7 @@ void NpuDeepseekV2DecoderLayerImpl::update_expert_weight() {
atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[index]);
prefill_node_.inTensors.at(index) = &atb_weight_tensors_[index];
decode_node_.inTensors.at(index) = &atb_weight_tensors_[index];
decode_mla_node_.inTensors.at(index) = &atb_weight_tensors_[index];
}
expert_routing_map_[layer_id_ - first_k_dense_replace_] =
expert_routing_map_buffer_;
Expand All @@ -1456,6 +1459,7 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_layer() {
model_name_ = "DeepSeek_V2";
CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
CHECK_OPERATION_STATUS_RETURN(init_node(decode_mla_node_, decode_mla_param_));
return atb::NO_ERROR;
}

Expand Down Expand Up @@ -1536,17 +1540,31 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
} else {
std::vector<torch::Tensor> attn_mask{tensor_placeholder_,
tensor_placeholder_};
build_node_variant_pack(decode_node_,
x,
cos_pos,
sin_pos,
attn_mask,
kv_cache,
input_params,
false);
st = execute_node(decode_node_, node_id + 1000, event, event_flag);
LOG_IF(FATAL, st != 0) << model_name_
<< "excute decode layer fail, error code: " << st;
if (!FLAGS_enable_customize_mla_kernel) {
build_node_variant_pack(decode_node_,
x,
cos_pos,
sin_pos,
attn_mask,
kv_cache,
input_params,
false);
st = execute_node(decode_node_, node_id + 1000, event, event_flag);
LOG_IF(FATAL, st != 0)
<< model_name_ << "excute decode layer fail, error code: " << st;
} else {
build_node_variant_pack(decode_mla_node_,
x,
cos_pos,
sin_pos,
attn_mask,
kv_cache,
input_params,
false);
st = execute_node(decode_mla_node_, node_id + 1000, event, event_flag);
LOG_IF(FATAL, st != 0)
<< model_name_ << "excute decode layer fail, error code: " << st;
}
}
return tensor_placeholder_;
}
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,11 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer {

atb_speed::deepseekV2::DecoderLayerParam prefill_param_;
atb_speed::deepseekV2::DecoderLayerParam decode_param_;
atb_speed::deepseekV2::DecoderLayerParam decode_mla_param_;

atb_speed::Model::Node prefill_node_;
atb_speed::Model::Node decode_node_;
atb_speed::Model::Node decode_mla_node_;

atb::Tensor internal_tensor_;
atb::Tensor internal_tensor_auxiliary_;
Expand Down
17 changes: 16 additions & 1 deletion xllm/core/runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,22 @@ ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
<< "The processed raw forward inputs size "
<< batched_raw_forward_inputs.size() << " is not equal to dp size "
<< dp_size_ << ".";

static bool set_enable_mla = FLAGS_enable_customize_mla_kernel;
// The customized MLA kernel produces an error when a decode-phase input
// carries more tokens than this limit. If any input exceeds the limit,
// fall back to the default kernel.
const int num_tokens_limit = 230;
if (set_enable_mla) {
FLAGS_enable_customize_mla_kernel = std::all_of(
batched_raw_forward_inputs.begin(),
batched_raw_forward_inputs.end(),
[](const std::vector<RawForwardInput>& inputs) {
return std::all_of(
inputs.begin(), inputs.end(), [](const RawForwardInput& input) {
return input.flatten_tokens_vec.size() < num_tokens_limit;
});
});
}
std::vector<folly::SemiFuture<std::optional<RawForwardOutput>>> futures;
futures.reserve(worker_clients_num_);

Expand Down