-
Notifications
You must be signed in to change notification settings - Fork 595
【Inference Optimize】optimize DeepSeek_v3 #3349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
fde43aa
optimize DeepSeek_v3 Eliminate redundant calculations & encoder using…
chang-wenbin ba92551
Merge branch 'develop' into DSK_OPT1
XieYunshen 696b0cc
Merge remote-tracking branch 'origin/develop' into DSK_OPT1
chang-wenbin 7924b8c
Merge remote-tracking branch 'cwb/DSK_OPT1' into DSK_OPT1
chang-wenbin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,11 @@ | |
import paddle | ||
from paddle.nn.functional.flash_attention import flash_attn_unpadded | ||
|
||
try: | ||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen | ||
except: | ||
flash_attention_v3_varlen = None | ||
|
||
from fastdeploy.model_executor.layers.attention.ops import ( | ||
get_block_shape_and_split_kv_block, | ||
init_kv_signal_per_query, | ||
|
@@ -91,6 +96,7 @@ class MLAAttentionBackend(AttentionBackend): | |
""" | ||
|
||
__infer_dynamic_dims_fields__ = ["attention_metadata"] | ||
flash_attn_func: callable = None | ||
attention_metadata: MLAAttentionMetadata | ||
|
||
def __init__( | ||
|
@@ -147,6 +153,21 @@ def __init__( | |
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) | ||
|
||
self.rank, self.device_id = init_rank_and_device_id(fd_config) | ||
if self.flash_attn_func is None: | ||
prop = paddle.device.cuda.get_device_properties() | ||
cc = prop.major * 10 + prop.minor | ||
is_current_sm_supported = cc >= 90 | ||
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs()) | ||
if is_current_sm_supported and is_paddle_supported: | ||
self.flash_attn_func = flash_attention_v3_varlen | ||
print("The current platform supports Flash Attention V3.") | ||
self.flash_attn_kwargs = {} | ||
else: | ||
self.flash_attn_func = flash_attn_unpadded | ||
self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False} | ||
print( | ||
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead." | ||
) | ||
|
||
def init_attention_metadata(self, forward_meta: ForwardMeta): | ||
"""Initialize attention metadata hence all layers in the forward pass can reuse it.""" | ||
|
@@ -269,17 +290,16 @@ def forward_extend( | |
) | ||
|
||
# Flash注意力计算 | ||
fmha_out = flash_attn_unpadded( | ||
fmha_out = self.flash_attn_func( | ||
q, | ||
k, | ||
v, | ||
forward_meta.cu_seqlens_q, | ||
forward_meta.cu_seqlens_k, | ||
metadata.max_enc_len_this_time, | ||
metadata.max_enc_len_this_time, | ||
self.attn_softmax_scale, | ||
causal=True, | ||
training=False, | ||
max_seqlen_q=forward_meta.max_len_tensor_cpu[0], | ||
max_seqlen_k=forward_meta.max_len_tensor_cpu[3], | ||
causal=self.causal, | ||
**self.flash_attn_kwargs, | ||
)[0] | ||
|
||
return fmha_out | ||
|
@@ -418,17 +438,16 @@ def forward_mixed( | |
) | ||
|
||
# FA | ||
fmha_out = flash_attn_unpadded( | ||
fmha_out = self.flash_attn_func( | ||
q, | ||
k, | ||
v, | ||
forward_meta.cu_seqlens_q, | ||
forward_meta.cu_seqlens_k, | ||
metadata.max_enc_len_this_time, | ||
metadata.max_enc_len_this_time, | ||
self.attn_softmax_scale, | ||
causal=True, | ||
training=False, | ||
max_seqlen_q=forward_meta.max_len_tensor_cpu[0], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 改之后,测过之前的flash_attn_unpadded没, 输出是否有变化 |
||
max_seqlen_k=forward_meta.max_len_tensor_cpu[3], | ||
causal=self.causal, | ||
**self.flash_attn_kwargs, | ||
)[0] | ||
|
||
return fmha_out | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -316,30 +316,23 @@ def forward( | |
mask_encoder_batch: paddle.Tensor, | ||
): | ||
""" """ | ||
layernorm_out = hidden_states | ||
fmha_out = paddle.zeros( | ||
shape=[ | ||
layernorm_out.shape[0], | ||
self.num_attention_heads_tp * self.v_head_dim, | ||
], | ||
dtype=layernorm_out.dtype, | ||
) | ||
|
||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time | ||
query = self.q_a_proj(layernorm_out) | ||
query = self.q_a_layernorm(query) | ||
query = self.q_b_proj(query) | ||
fmha_out = None | ||
# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation. | ||
query = self.q_a_proj(hidden_states) | ||
query = self.q_a_layernorm(query) | ||
query = self.q_b_proj(query) | ||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim]) | ||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) | ||
|
||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim]) | ||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) | ||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states) | ||
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) | ||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) | ||
compressed_kv = self.kv_a_layernorm(compressed_kv) | ||
|
||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out) | ||
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) | ||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) | ||
compressed_kv = self.kv_a_layernorm(compressed_kv) | ||
|
||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) | ||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) | ||
|
||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time | ||
# NOTE: (changwenbin) We will take the public part | ||
key_value = self.kv_b_proj(compressed_kv) | ||
key_value = key_value.reshape( | ||
[ | ||
|
@@ -371,23 +364,10 @@ def forward( | |
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim]) | ||
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype) | ||
|
||
fmha_out = fmha_out + fmha_out_prefill | ||
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time | ||
query = self.q_a_proj(layernorm_out) | ||
query = self.q_a_layernorm(query) | ||
ln_out_or_q_c = query | ||
|
||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out) | ||
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) | ||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim]) | ||
compressed_kv = self.kv_a_layernorm(compressed_kv) | ||
|
||
query = self.q_b_proj(ln_out_or_q_c) | ||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim]) | ||
|
||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) | ||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe) | ||
fmha_out = fmha_out_prefill | ||
|
||
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time | ||
# NOTE: (changwenbin) We will take the public part | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. take 如何理解, 是准备表达will solove public part,还是会计算public part |
||
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2]) | ||
|
||
q_input = paddle.concat([q_nope_out, query_pe], axis=-1) | ||
|
@@ -416,7 +396,10 @@ def forward( | |
.transpose([1, 0, 2]) | ||
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim]) | ||
) | ||
fmha_out = fmha_out + fmha_out_decode | ||
if fmha_out is None: | ||
fmha_out = fmha_out_decode | ||
else: | ||
fmha_out = fmha_out + fmha_out_decode | ||
|
||
output = self.o_proj(fmha_out) | ||
return output | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.attn_softmax_scale 这个参数,这样传入的话,会用你上面创建的{"scale": self.head_dim**-0.5, 吗?
self.attn_softmax_scale 看之前的代码,是经过self.attn_softmax_scale * mscale * mscale计算出来的