File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
tritonbench/operators/blackwell_attentions Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change 6
6
7
7
8
8
import argparse
9
+ import math
9
10
import os
10
11
from contextlib import nullcontext
11
12
@@ -109,6 +110,9 @@ def parse_op_args(args: List[str]):
109
110
parser .add_argument (
110
111
"--pt2-sdpa" , action = "store_true" , help = "Compile SDPA with PT2."
111
112
)
113
+ parser .add_argument (
114
+ "--sm-scale" , type = Optional [float ], default = None , help = "softmax scale"
115
+ )
112
116
parser .add_argument (
113
117
"--input-types" ,
114
118
type = str ,
@@ -138,7 +142,7 @@ def __init__(
138
142
self .native_sdpa = args .native_sdpa
139
143
self .pt2_sdpa = args .pt2_sdpa
140
144
self .input_types = args .input_types
141
- self .sm_scale = 1.3
145
+ self .sm_scale = args . sm_scale if args . sm_scale else 1.0 / math . sqrt ( self . D_HEAD )
142
146
143
147
@register_benchmark ()
144
148
def aten (
You can’t perform that action at this time.
0 commit comments