Skip to content

Commit a19e0c1

Browse files
committed
feat: add control to enable customized mla operation.
1 parent 53b6e6f commit a19e0c1

File tree

5 files changed

+51
-12
lines changed

5 files changed

+51
-12
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ DEFINE_bool(enable_mla,
7979
false,
8080
"Whether to enable multi-head latent attention.");
8181

82+
DEFINE_bool(enable_customize_mla_kernel, false, "enable customize mla kernel");
83+
8284
// --- graph mode execution config ---
8385

8486
DEFINE_bool(enable_acl_graph,

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ DECLARE_int32(expert_parallel_degree);
125125

126126
DECLARE_int32(max_reconnect_count);
127127

128+
DECLARE_bool(enable_customize_mla_kernel);
129+
128130
DECLARE_bool(enable_atb_comm_multiprocess);
129131

130132
DECLARE_string(tool_call_parser);

xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ NpuDeepseekV2DecoderLayerImpl::NpuDeepseekV2DecoderLayerImpl(
286286

287287
param_from_args(prefill_param_, model_args, parallel_args, true);
288288
param_from_args(decode_param_, model_args, parallel_args, false);
289+
param_from_args(decode_mla_param_, model_args, parallel_args, false);
290+
decode_mla_param_.enableCustomizeMla = FLAGS_enable_customize_mla_kernel;
289291

290292
initialize_tensors(options);
291293
}
@@ -1437,6 +1439,7 @@ void NpuDeepseekV2DecoderLayerImpl::update_expert_weight() {
14371439
atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[index]);
14381440
prefill_node_.inTensors.at(index) = &atb_weight_tensors_[index];
14391441
decode_node_.inTensors.at(index) = &atb_weight_tensors_[index];
1442+
decode_mla_node_.inTensors.at(index) = &atb_weight_tensors_[index];
14401443
}
14411444
expert_routing_map_[layer_id_ - first_k_dense_replace_] =
14421445
expert_routing_map_buffer_;
@@ -1456,6 +1459,7 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_layer() {
14561459
model_name_ = "DeepSeek_V2";
14571460
CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
14581461
CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_));
1462+
CHECK_OPERATION_STATUS_RETURN(init_node(decode_mla_node_, decode_mla_param_));
14591463
return atb::NO_ERROR;
14601464
}
14611465

@@ -1536,17 +1540,31 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
15361540
} else {
15371541
std::vector<torch::Tensor> attn_mask{tensor_placeholder_,
15381542
tensor_placeholder_};
1539-
build_node_variant_pack(decode_node_,
1540-
x,
1541-
cos_pos,
1542-
sin_pos,
1543-
attn_mask,
1544-
kv_cache,
1545-
input_params,
1546-
false);
1547-
st = execute_node(decode_node_, node_id + 1000, event, event_flag);
1548-
LOG_IF(FATAL, st != 0) << model_name_
1549-
<< "excute decode layer fail, error code: " << st;
1543+
if (!FLAGS_enable_customize_mla_kernel) {
1544+
build_node_variant_pack(decode_node_,
1545+
x,
1546+
cos_pos,
1547+
sin_pos,
1548+
attn_mask,
1549+
kv_cache,
1550+
input_params,
1551+
false);
1552+
st = execute_node(decode_node_, node_id + 1000, event, event_flag);
1553+
LOG_IF(FATAL, st != 0)
1554+
<< model_name_ << "excute decode layer fail, error code: " << st;
1555+
} else {
1556+
build_node_variant_pack(decode_mla_node_,
1557+
x,
1558+
cos_pos,
1559+
sin_pos,
1560+
attn_mask,
1561+
kv_cache,
1562+
input_params,
1563+
false);
1564+
st = execute_node(decode_mla_node_, node_id + 1000, event, event_flag);
1565+
LOG_IF(FATAL, st != 0)
1566+
<< model_name_ << "excute decode layer fail, error code: " << st;
1567+
}
15501568
}
15511569
return tensor_placeholder_;
15521570
}

xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,11 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer {
313313

314314
atb_speed::deepseekV2::DecoderLayerParam prefill_param_;
315315
atb_speed::deepseekV2::DecoderLayerParam decode_param_;
316+
atb_speed::deepseekV2::DecoderLayerParam decode_mla_param_;
316317

317318
atb_speed::Model::Node prefill_node_;
318319
atb_speed::Model::Node decode_node_;
320+
atb_speed::Model::Node decode_mla_node_;
319321

320322
atb::Tensor internal_tensor_;
321323
atb::Tensor internal_tensor_auxiliary_;

xllm/core/runtime/llm_engine.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,22 @@ ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
664664
<< "The processed raw forward inputs size "
665665
<< batched_raw_forward_inputs.size() << " is not equal to dp size "
666666
<< dp_size_ << ".";
667-
667+
static bool set_enable_mla = FLAGS_enable_customize_mla_kernel;
668+
// A decode phase with more tokens than this limit will cause an error in the
669+
// customized MLA kernel. Once any input exceeding the limit is detected, fall
670+
// back to the default kernel.
671+
const int num_tokens_limit = 230;
672+
if (set_enable_mla) {
673+
FLAGS_enable_customize_mla_kernel = std::all_of(
674+
batched_raw_forward_inputs.begin(),
675+
batched_raw_forward_inputs.end(),
676+
[](const std::vector<RawForwardInput>& inputs) {
677+
return std::all_of(
678+
inputs.begin(), inputs.end(), [](const RawForwardInput& input) {
679+
return input.flatten_tokens_vec.size() < num_tokens_limit;
680+
});
681+
});
682+
}
668683
std::vector<folly::SemiFuture<std::optional<RawForwardOutput>>> futures;
669684
futures.reserve(worker_clients_num_);
670685

0 commit comments

Comments
 (0)