@@ -286,6 +286,8 @@ NpuDeepseekV2DecoderLayerImpl::NpuDeepseekV2DecoderLayerImpl(
286286
287287 param_from_args (prefill_param_, model_args, parallel_args, true );
288288 param_from_args (decode_param_, model_args, parallel_args, false );
289+ param_from_args (decode_mla_param_, model_args, parallel_args, false );
290+ decode_mla_param_.enableCustomizeMla = FLAGS_enable_customize_mla_kernel;
289291
290292 initialize_tensors (options);
291293}
@@ -1437,6 +1439,7 @@ void NpuDeepseekV2DecoderLayerImpl::update_expert_weight() {
14371439 atb_speed::Utils::AtTensor2Tensor (at_weight_tensors_[index]);
14381440 prefill_node_.inTensors .at (index) = &atb_weight_tensors_[index];
14391441 decode_node_.inTensors .at (index) = &atb_weight_tensors_[index];
1442+ decode_mla_node_.inTensors .at (index) = &atb_weight_tensors_[index];
14401443 }
14411444 expert_routing_map_[layer_id_ - first_k_dense_replace_] =
14421445 expert_routing_map_buffer_;
@@ -1456,6 +1459,7 @@ int64_t NpuDeepseekV2DecoderLayerImpl::init_layer() {
14561459 model_name_ = " DeepSeek_V2" ;
14571460 CHECK_OPERATION_STATUS_RETURN (init_node (prefill_node_, prefill_param_));
14581461 CHECK_OPERATION_STATUS_RETURN (init_node (decode_node_, decode_param_));
1462+ CHECK_OPERATION_STATUS_RETURN (init_node (decode_mla_node_, decode_mla_param_));
14591463 return atb::NO_ERROR;
14601464}
14611465
@@ -1536,17 +1540,31 @@ torch::Tensor NpuDeepseekV2DecoderLayerImpl::forward(
15361540 } else {
15371541 std::vector<torch::Tensor> attn_mask{tensor_placeholder_,
15381542 tensor_placeholder_};
1539- build_node_variant_pack (decode_node_,
1540- x,
1541- cos_pos,
1542- sin_pos,
1543- attn_mask,
1544- kv_cache,
1545- input_params,
1546- false );
1547- st = execute_node (decode_node_, node_id + 1000 , event, event_flag);
1548- LOG_IF (FATAL, st != 0 ) << model_name_
1549- << " excute decode layer fail, error code: " << st;
1543+ if (!FLAGS_enable_customize_mla_kernel) {
1544+ build_node_variant_pack (decode_node_,
1545+ x,
1546+ cos_pos,
1547+ sin_pos,
1548+ attn_mask,
1549+ kv_cache,
1550+ input_params,
1551+ false );
1552+ st = execute_node (decode_node_, node_id + 1000 , event, event_flag);
1553+ LOG_IF (FATAL, st != 0 )
1554+ << model_name_ << " excute decode layer fail, error code: " << st;
1555+ } else {
1556+ build_node_variant_pack (decode_mla_node_,
1557+ x,
1558+ cos_pos,
1559+ sin_pos,
1560+ attn_mask,
1561+ kv_cache,
1562+ input_params,
1563+ false );
1564+ st = execute_node (decode_mla_node_, node_id + 1000 , event, event_flag);
1565+ LOG_IF (FATAL, st != 0 )
1566+ << model_name_ << " excute decode layer fail, error code: " << st;
1567+ }
15501568 }
15511569 return tensor_placeholder_;
15521570}
0 commit comments