diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 9e41164e..5c930bab 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -79,6 +79,8 @@ DEFINE_bool(enable_mla, false, "Whether to enable multi-head latent attention."); +DEFINE_bool(enable_customize_mla_kernel, false, "enable customize mla kernel"); + // --- graph mode execution config --- DEFINE_bool(enable_acl_graph, diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 7fc36442..6e951d1e 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -125,6 +125,8 @@ DECLARE_int32(expert_parallel_degree); DECLARE_int32(max_reconnect_count); +DECLARE_bool(enable_customize_mla_kernel); + DECLARE_bool(enable_atb_comm_multiprocess); DECLARE_string(tool_call_parser); diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index 0ae6cedb..f8417d6e 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -286,6 +286,8 @@ NpuDeepseekV2DecoderLayerImpl::NpuDeepseekV2DecoderLayerImpl( param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); + param_from_args(decode_mla_param_, model_args, parallel_args, false); + decode_mla_param_.enableCustomizeMla = FLAGS_enable_customize_mla_kernel; initialize_tensors(options); } @@ -1437,6 +1439,7 @@ void NpuDeepseekV2DecoderLayerImpl::update_expert_weight() { atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[index]); prefill_node_.inTensors.at(index) = &atb_weight_tensors_[index]; decode_node_.inTensors.at(index) = &atb_weight_tensors_[index]; + decode_mla_node_.inTensors.at(index) = &atb_weight_tensors_[index]; } expert_routing_map_[layer_id_ - first_k_dense_replace_] = expert_routing_map_buffer_; @@ -1456,6 +1459,7 @@ 
int64_t NpuDeepseekV2DecoderLayerImpl::init_layer() { model_name_ = "DeepSeek_V2"; CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); CHECK_OPERATION_STATUS_RETURN(init_node(decode_node_, decode_param_)); + CHECK_OPERATION_STATUS_RETURN(init_node(decode_mla_node_, decode_mla_param_)); return atb::NO_ERROR; } @@ -1536,17 +1540,31 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward( } else { std::vector attn_mask{tensor_placeholder_, tensor_placeholder_}; - build_node_variant_pack(decode_node_, - x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - false); - st = execute_node(decode_node_, node_id + 1000, event, event_flag); - LOG_IF(FATAL, st != 0) << model_name_ - << "excute decode layer fail, error code: " << st; + if (!FLAGS_enable_customize_mla_kernel) { + build_node_variant_pack(decode_node_, + x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + false); + st = execute_node(decode_node_, node_id + 1000, event, event_flag); + LOG_IF(FATAL, st != 0) + << model_name_ << "excute decode layer fail, error code: " << st; + } else { + build_node_variant_pack(decode_mla_node_, + x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + false); + st = execute_node(decode_mla_node_, node_id + 1000, event, event_flag); + LOG_IF(FATAL, st != 0) + << model_name_ << "excute decode layer fail, error code: " << st; + } } return tensor_placeholder_; } diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index 00c830c4..98996f2d 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -313,9 +313,11 @@ class NpuDeepseekV2DecoderLayerImpl : public NpuBaseLayer { atb_speed::deepseekV2::DecoderLayerParam prefill_param_; atb_speed::deepseekV2::DecoderLayerParam decode_param_; + atb_speed::deepseekV2::DecoderLayerParam decode_mla_param_; atb_speed::Model::Node 
prefill_node_; atb_speed::Model::Node decode_node_; + atb_speed::Model::Node decode_mla_node_; atb::Tensor internal_tensor_; atb::Tensor internal_tensor_auxiliary_; diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index b2eddb5c..bc432375 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -664,7 +664,22 @@ ForwardOutput LLMEngine::step(std::vector& batch) { << "The processed raw forward inputs size " << batched_raw_forward_inputs.size() << " is not equal to dp size " << dp_size_ << "."; - + static bool set_enable_mla = FLAGS_enable_customize_mla_kernel; + // decode phase with more tokens than this limit will lead to errors in the + customize mla kernel. once any input exceeds the limit, fall back to the + default kernel. + const int num_tokens_limit = 230; + if (set_enable_mla) { + FLAGS_enable_customize_mla_kernel = std::all_of( + batched_raw_forward_inputs.begin(), + batched_raw_forward_inputs.end(), + [](const std::vector& inputs) { + return std::all_of( + inputs.begin(), inputs.end(), [](const RawForwardInput& input) { + return input.flatten_tokens_vec.size() < num_tokens_limit; + }); + }); + } std::vector>> futures; futures.reserve(worker_clients_num_);