Commit 1278ae7

Refactor attn-kernel benchmark and hotfix a logging bug (#227)
* refactor attn-kernel benchmark
* removed debug gpu-cpu sync
* fix some issues in bwd
* fix some issues

Co-authored-by: Strivin0311 <hyp@smail.nju.edu.cn>
1 parent 5e9a442 commit 1278ae7

File tree

7 files changed: +1561 −47 lines


exps/attn/baselines/attn_impl.py

Lines changed: 22 additions & 0 deletions
@@ -53,15 +53,28 @@
 )
 
 try:
+    import transformer_engine as te
     from transformer_engine.pytorch.attention.dot_product_attention.backends import (
         FusedAttnFunc,
     )
+
+    te_2_9_0 = version.parse(te.__version__) >= version.parse("2.9.0")
 except ImportError:
     FusedAttnFunc = missing_dependency(
         dep_name="transformer_engine",
         func_name="FusedAttnFunc",
     )
 
+try:
+    from paddle.nn.functional.flash_attention import (
+        flashmask_attention as flashmask_func,
+    )
+except ImportError:
+    flashmask_func = missing_dependency(
+        dep_name="paddle",
+        func_name="flashmask_attention",
+    )
+
 if version.parse(torch.__version__) > version.parse("2.4"):
     # NOTE: in benchmarking, we should explicitly allow bf16/fp16 reduction for sdpa
     # by setting `torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)`
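
Both guarded imports above fall back to `missing_dependency`, a repo helper whose definition sits outside this diff. A minimal sketch of what such a stub factory could look like, assuming it simply defers the ImportError until the missing function is actually called:

def missing_dependency(dep_name: str, func_name: str):
    # Hypothetical reconstruction; the real helper lives elsewhere in the
    # repo and may differ. It returns a stand-in that raises only on use,
    # so importing attn_impl.py succeeds without the optional dependency.
    def _stub(*args, **kwargs):
        raise ImportError(
            f"{func_name} requires the optional dependency '{dep_name}', "
            "which is not installed."
        )

    return _stub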
@@ -174,6 +187,13 @@ def cudnn_fused_attn_func(
     softmax_offset = None
     layer_number = 1
 
+    if te_2_9_0:
+        fused_attn_args = [
+            False,  # return_max_logit
+        ]
+    else:
+        fused_attn_args = []
+
     output = FusedAttnFunc.apply(
         is_training,
         max_seqlen_q,
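
This new block handles an API change: Transformer Engine 2.9.0 appears to add a trailing `return_max_logit` argument to `FusedAttnFunc.apply`, so the benchmark builds the extra positional conditionally and splats it into the call (see the `*fused_attn_args` line in the next hunk). The same version-gating idiom in isolation, with illustrative names:

from packaging import version

def call_with_gated_args(fn, installed_version, *base_args):
    # Illustrative sketch, not part of the diff: append arguments that
    # only exist in newer releases of a dependency, then splat them in.
    extra_args = []
    if version.parse(installed_version) >= version.parse("2.9.0"):
        extra_args.append(False)  # e.g. return_max_logit, new in 2.9.0
    return fn(*base_args, *extra_args)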
@@ -206,6 +226,7 @@ def cudnn_fused_attn_func(
         softmax_offset,
         fp8_output,
         layer_number,
+        *fused_attn_args,
     )
 
     return output
@@ -224,4 +245,5 @@ def cudnn_fused_attn_func(
     "fa2_varlen_func",
     "fa3_varlen_func",
     "fa4_varlen_func",
+    "flashmask_func",
 ]
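
The NOTE in the first hunk refers to a real PyTorch switch: on PyTorch newer than 2.4, the math SDPA backend can be allowed to accumulate in fp16/bf16 so the baseline matches the reduced-precision behavior of the fused kernels. A minimal setup sketch along the lines the comment suggests:

import torch
from packaging import version

# Allow fp16/bf16 accumulation in the math SDPA backend so the baseline
# is comparable with fused attention kernels; the setter exists only on
# newer PyTorch, hence the version gate mirrored from the diff.
if version.parse(torch.__version__) > version.parse("2.4"):
    torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)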
