【Inference Optimize】optimize DeepSeek_v3 #3349

Open · wants to merge 2 commits into develop
fastdeploy/model_executor/layers/attention/mla_attention_backend.py (31 additions, 12 deletions)
@@ -24,6 +24,11 @@
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded

+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except ImportError:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
@@ -91,6 +96,7 @@ class MLAAttentionBackend(AttentionBackend):
     """

     __infer_dynamic_dims_fields__ = ["attention_metadata"]
+    flash_attn_func: callable = None
     attention_metadata: MLAAttentionMetadata

     def __init__(
@@ -147,6 +153,21 @@ def __init__(
         self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)

         self.rank, self.device_id = init_rank_and_device_id(fd_config)
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
@@ -269,17 +290,16 @@ def forward_extend(
         )

         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]

         return fmha_out
@@ -418,17 +438,16 @@ def forward_mixed(
         )

         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
             forward_meta.cu_seqlens_q,
             forward_meta.cu_seqlens_k,
-            metadata.max_enc_len_this_time,
-            metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]

         return fmha_out
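The core of this file's change is the capability probe in `__init__`: FA3 is used only when both the running GPU (SM >= 90) and the installed Paddle build (an SM90+ entry in `paddle.version.cuda_archs()`) support it; otherwise the existing FA2 varlen kernel is kept, with its extra `scale`/`training` arguments supplied through `flash_attn_kwargs`. Below is a minimal standalone sketch of that selection, using only APIs that already appear in the diff; the helper name `pick_flash_attn_func` and the extra `is not None` guard are illustrative additions, not part of the PR.

import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded

try:
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except ImportError:  # older Paddle builds do not ship the FA3 varlen kernel
    flash_attention_v3_varlen = None


def pick_flash_attn_func(head_dim: int):
    """Return (attention_func, extra_kwargs) for the current GPU.

    Illustrative helper: FA3 needs a Hopper-class device (SM >= 90) *and*
    a Paddle build compiled for SM90+; otherwise fall back to the FA2
    varlen kernel, which additionally expects `scale` and `training`.
    """
    prop = paddle.device.cuda.get_device_properties()
    compute_capability = prop.major * 10 + prop.minor
    device_ok = compute_capability >= 90
    build_ok = any(arch >= 90 for arch in paddle.version.cuda_archs())
    if device_ok and build_ok and flash_attention_v3_varlen is not None:
        return flash_attention_v3_varlen, {}
    return flash_attn_unpadded, {"scale": head_dim**-0.5, "training": False}

The chosen function and its extra kwargs are stored on the backend once, which is why `forward_extend` and `forward_mixed` now call `self.flash_attn_func(..., **self.flash_attn_kwargs)` instead of hard-coding `flash_attn_unpadded`.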
fastdeploy/model_executor/models/deepseek_v3.py (21 additions, 38 deletions)
@@ -316,30 +316,23 @@ def forward(
         mask_encoder_batch: paddle.Tensor,
     ):
         """ """
-        layernorm_out = hidden_states
-        fmha_out = paddle.zeros(
-            shape=[
-                layernorm_out.shape[0],
-                self.num_attention_heads_tp * self.v_head_dim,
-            ],
-            dtype=layernorm_out.dtype,
-        )
-
-        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            query = self.q_b_proj(query)
-
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-
+        fmha_out = None
+        # NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
+        query = self.q_a_proj(hidden_states)
+        query = self.q_a_layernorm(query)
+        query = self.q_b_proj(query)
+        query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
+        query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
+        key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+
+        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+
+        if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
+            # NOTE: (changwenbin) We will take the public part
             key_value = self.kv_b_proj(compressed_kv)
             key_value = key_value.reshape(
                 [
@@ -371,23 +364,10 @@ def forward(
             fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

-            fmha_out = fmha_out + fmha_out_prefill
-        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            query = self.q_a_proj(layernorm_out)
-            query = self.q_a_layernorm(query)
-            ln_out_or_q_c = query
-
-            compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
-            compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
-            key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-
-            query = self.q_b_proj(ln_out_or_q_c)
-            query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
-
-            query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
-            query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-
+            fmha_out = fmha_out_prefill
+
+        if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
+            # NOTE: (changwenbin) We will take the public part
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,7 +396,10 @@
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )
-            fmha_out = fmha_out + fmha_out_decode
+            if fmha_out is None:
+                fmha_out = fmha_out_decode
+            else:
+                fmha_out = fmha_out + fmha_out_decode

         output = self.o_proj(fmha_out)
         return output
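Taken together, the deepseek_v3.py change hoists the query/KV down-projections and the rotary embedding out of the prefill and decode branches so a mixed (PD) batch computes them once, and it replaces the pre-allocated zero `fmha_out` with a lazily assigned accumulator. A control-flow schematic of the restructured forward pass is sketched below; it is an illustration only: `_shared_projections`, `_prefill_attention`, and `_decode_attention` are hypothetical stand-ins for the code shown in the diff, while the attribute names come from the diff itself.

def forward(self, forward_meta, hidden_states, position_ids, mask_encoder_batch):
    fmha_out = None

    # Shared ("public") work, previously duplicated in both branches, now done once:
    # q_a_proj -> q_a_layernorm -> q_b_proj, kv_a_proj_with_mqa -> kv_a_layernorm,
    # and the rotary embedding applied to (query_pe, key_pe).
    query_nope, query_pe, compressed_kv, key_pe = self._shared_projections(
        hidden_states, position_ids
    )  # hypothetical helper bundling the hoisted lines

    if forward_meta.max_len_tensor_cpu[1]:  # prefill tokens in this batch
        # hypothetical stand-in for the varlen flash-attention prefill path
        fmha_out = self._prefill_attention(
            query_nope, query_pe, compressed_kv, key_pe, forward_meta, mask_encoder_batch
        )

    if forward_meta.max_len_tensor_cpu[2]:  # decode tokens in this batch
        # hypothetical stand-in for the absorbed-weight MLA decode path
        fmha_out_decode = self._decode_attention(
            query_nope, query_pe, compressed_kv, key_pe, forward_meta
        )
        # Accumulate only what the batch needs; the old code always started
        # from a zero tensor of shape [num_tokens, heads * v_head_dim].
        fmha_out = fmha_out_decode if fmha_out is None else fmha_out + fmha_out_decode

    return self.o_proj(fmha_out)

Because the prefill output is already masked with `mask_encoder_batch`, summing the two partial outputs in a mixed batch covers every token, which is presumably why the zero-initialized `fmha_out` could be dropped.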