Commit 16eddff

Default sm_scale to None (#344)
Summary: To make the behavior consistent with FA4: https://www.internalfb.com/code/fbsource/fbcode/ai_codesign/gen_ai/flash_attention_v2/benchmarks/benchmark_attn.py

In the FA4 benchmark, sm_scale defaults to None, which further translates to `1.0 / math.sqrt(head_dim)` (https://www.internalfb.com/code/fbsource/[f03d88961ee4ecd3f4ee76736d7de904351d295c]/fbcode/ai_codesign/gen_ai/flash_attention_v2/flash_attn/cute/interface.py?lines=112), whereas Tritonbench pins sm_scale to 1.3. This change sets Tritonbench's sm_scale default to None as well.

Reviewed By: njriasan, jduprat

Differential Revision: D80652124
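For reference, a minimal sketch of the resulting default behavior (the standalone helper and example values below are illustrative, not the actual Tritonbench or FA4 interface):

```python
import math

def resolve_sm_scale(sm_scale, head_dim):
    # FA4 convention: an unset softmax scale falls back to 1/sqrt(head_dim).
    if sm_scale is not None:
        return sm_scale
    return 1.0 / math.sqrt(head_dim)

print(resolve_sm_scale(None, 128))  # ~0.0884, instead of the previously pinned 1.3
print(resolve_sm_scale(1.3, 128))   # 1.3, explicitly pinned
```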
1 parent 34e755a commit 16eddff

File tree: 1 file changed (+5, -1 lines)

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 5 additions & 1 deletion

@@ -6,6 +6,7 @@


 import argparse
+import math
 import os
 from contextlib import nullcontext


@@ -109,6 +110,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument(
         "--pt2-sdpa", action="store_true", help="Compile SDPA with PT2."
     )
+    parser.add_argument(
+        "--sm-scale", type=Optional[float], default=None, help="softmax scale"
+    )
     parser.add_argument(
         "--input-types",
         type=str,

@@ -138,7 +142,7 @@ def __init__(
         self.native_sdpa = args.native_sdpa
         self.pt2_sdpa = args.pt2_sdpa
         self.input_types = args.input_types
-        self.sm_scale = 1.3
+        self.sm_scale = args.sm_scale if args.sm_scale else 1.0 / math.sqrt(self.D_HEAD)

     @register_benchmark()
     def aten(
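With the new flag in place, passing `--sm-scale 1.3` to the blackwell_attentions operator reproduces the previously pinned behavior, while omitting the flag now yields `1.0 / math.sqrt(head_dim)`, matching the FA4 default.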