From fde43aa0bd84c532a7ddf7310f52997c49f5e819 Mon Sep 17 00:00:00 2001
From: changwenbin
Date: Tue, 12 Aug 2025 15:10:19 +0800
Subject: [PATCH] optimize DeepSeek_v3

Eliminate redundant calculations and use FA3 for the encoder.
---
 .../layers/attention/mla_attention_backend.py | 43 ++++++++++----
 .../model_executor/models/deepseek_v3.py      | 59 +++++++------------
 2 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 5279b68f6f..8f2c58ac93 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -24,6 +24,11 @@
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except ImportError:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
@@ -91,6 +96,7 @@ class MLAAttentionBackend(AttentionBackend):
     """
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    flash_attn_func: callable = None
     attention_metadata: MLAAttentionMetadata
 
     def __init__(
@@ -147,6 +153,21 @@ def __init__(
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
@@ -269,17 +290,16 @@ def forward_extend(
         )
 
         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
@@ -418,17 +438,16 @@ def forward_mixed(
         )
 
         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]
 
         return fmha_out
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 03f6cea76c..74bcee28b9 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -316,30 +316,23 @@ def forward(
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            query = self.q_b_proj(query)
+        fmha_out = None
+        # NOTE: (changwenbin) Hoist the projections shared by the prefill and decode branches (PD mixed batches) so they are computed only once.
+        query = self.q_a_proj(hidden_states)
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
 
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
+            # NOTE: (changwenbin) Reuse the shared projections computed above.
 
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
@@ -371,23 +364,10 @@ def forward(
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
 
-            fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+            fmha_out = fmha_out_prefill
+        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
+            # NOTE: (changwenbin) Reuse the shared projections computed above.
 
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
 
             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,7 +396,10 @@ def forward(
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode
 
         output = self.o_proj(fmha_out)
         return output
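Reviewer note: below is a minimal, self-contained sketch of the FA3/FA2 selection logic that the first file's hunks add to MLAAttentionBackend.__init__. It assumes a CUDA-enabled Paddle build; the helper name select_flash_attn is hypothetical, and only the checks and the returned (function, kwargs) pair mirror the patch.

    # Hypothetical sketch of the patch's FA3/FA2 dispatch (not FastDeploy code).
    # flash_attention_v3_varlen is only importable from Paddle builds that ship
    # FA3 support, hence the guarded import.
    import paddle
    from paddle.nn.functional.flash_attention import flash_attn_unpadded

    try:
        from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
    except ImportError:
        flash_attention_v3_varlen = None

    def select_flash_attn(head_dim: int):
        """Return (flash_attn_func, extra_kwargs) for the current GPU."""
        prop = paddle.device.cuda.get_device_properties()
        cc = prop.major * 10 + prop.minor  # e.g. 90 for Hopper (SM90)
        paddle_has_sm90 = any(arch >= 90 for arch in paddle.version.cuda_archs())
        if flash_attention_v3_varlen is not None and cc >= 90 and paddle_has_sm90:
            # FA3 path: the patch passes no extra kwargs to the varlen kernel.
            return flash_attention_v3_varlen, {}
        # FA2 fallback: softmax scale and training flag are passed explicitly.
        return flash_attn_unpadded, {"scale": head_dim**-0.5, "training": False}

Routing both forward_extend and forward_mixed through self.flash_attn_func(..., **self.flash_attn_kwargs) keeps a single call site that works unchanged with either kernel.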
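Reviewer note: the deepseek_v3.py change is easier to follow with the MLA details stripped away. The toy below mirrors only the control flow of the refactored forward: work shared by the prefill and decode branches is hoisted and computed once, and fmha_out starts as None instead of a pre-allocated zero tensor, so each branch contributes only if it actually ran. The helpers and arithmetic are hypothetical stand-ins, not FastDeploy APIs.

    # Toy control-flow sketch of the refactor (hypothetical; runs on CPU Paddle).
    import paddle

    def mla_forward_sketch(hidden_states, has_prefill, has_decode):
        # Shared work, hoisted out of both branches
        # (q/kv projections + rotary embedding in the real code).
        shared = hidden_states * 2.0  # stand-in for the hoisted projections
        fmha_out = None
        if has_prefill:  # forward_meta.max_len_tensor_cpu[1] in the real code
            fmha_out = shared + 1.0  # stand-in for the flash-attention prefill path
        if has_decode:  # forward_meta.max_len_tensor_cpu[2] in the real code
            decode_out = shared - 1.0  # stand-in for the MQA decode path
            fmha_out = decode_out if fmha_out is None else fmha_out + decode_out
        return fmha_out

    print(mla_forward_sketch(paddle.ones([2, 4]), has_prefill=True, has_decode=True))

Compared with the old code, the shared projections are no longer recomputed in each branch, and the paddle.zeros allocation plus unconditional additions are replaced by the None check.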