@@ -90,8 +90,14 @@ def attention_varlen_ref(
90
90
@pytest .mark .parametrize ("kv_len" , [1 , 17 , 544 , 977 , 1999 ])
91
91
@pytest .mark .parametrize ("num_qo_heads" , [32 ])
92
92
@pytest .mark .parametrize ("num_kv_heads" , [8 , 32 ])
93
- @pytest .mark .parametrize ("head_dim_qk,head_dim_vo" , [(192 , 128 ), (128 , 128 ), (64 , 64 )])
94
- @pytest .mark .parametrize ("sm_scale" , [1.0 , 1.0 / math .sqrt (192 ), 1.0 / math .sqrt (128 )])
93
+ @pytest .mark .parametrize (
94
+ "head_dim_qk,head_dim_vo,sm_scale" ,
95
+ [
96
+ (192 , 128 , 1.0 / math .sqrt (192 )),
97
+ (128 , 128 , 1.0 / math .sqrt (128 )),
98
+ (64 , 64 , 1.0 / math .sqrt (64 )),
99
+ ],
100
+ )
95
101
@pytest .mark .parametrize ("causal" , [False , True ])
96
102
@pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
97
103
def test_blackwell_cutlass_fmha (
@@ -168,8 +174,14 @@ def test_blackwell_cutlass_fmha(
168
174
@pytest .mark .parametrize ("indptr" , VARLEN_INDPTR_PARAMS )
169
175
@pytest .mark .parametrize ("num_qo_heads" , [32 ])
170
176
@pytest .mark .parametrize ("num_kv_heads" , [8 , 32 ])
171
- @pytest .mark .parametrize ("head_dim_qk,head_dim_vo" , [(192 , 128 ), (128 , 128 ), (64 , 64 )])
172
- @pytest .mark .parametrize ("sm_scale" , [1.0 / math .sqrt (128 )])
177
+ @pytest .mark .parametrize (
178
+ "head_dim_qk,head_dim_vo,sm_scale" ,
179
+ [
180
+ (192 , 128 , 1.0 / math .sqrt (192 )),
181
+ (128 , 128 , 1.0 / math .sqrt (128 )),
182
+ (64 , 64 , 1.0 / math .sqrt (64 )),
183
+ ],
184
+ )
173
185
@pytest .mark .parametrize ("causal" , [False , True ])
174
186
@pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
175
187
def test_blackwell_cutlass_varlen (
@@ -249,8 +261,14 @@ def test_blackwell_cutlass_varlen(
249
261
@pytest .mark .parametrize ("kv_indptr_list" , [[0 , 50 , 50 , 50 , 50 , 50 , 50 , 50 ]])
250
262
@pytest .mark .parametrize ("num_qo_heads" , [32 ])
251
263
@pytest .mark .parametrize ("num_kv_heads" , [8 , 32 ])
252
- @pytest .mark .parametrize ("head_dim_qk,head_dim_vo" , [(192 , 128 ), (128 , 128 ), (64 , 64 )])
253
- @pytest .mark .parametrize ("sm_scale" , [1.0 / math .sqrt (128 )])
264
+ @pytest .mark .parametrize (
265
+ "head_dim_qk,head_dim_vo,sm_scale" ,
266
+ [
267
+ (192 , 128 , 1.0 / math .sqrt (192 )),
268
+ (128 , 128 , 1.0 / math .sqrt (128 )),
269
+ (64 , 64 , 1.0 / math .sqrt (64 )),
270
+ ],
271
+ )
254
272
@pytest .mark .parametrize ("dtype" , [torch .half , torch .bfloat16 ])
255
273
def test_blackwell_cutlass_qo_kv_varlen (
256
274
qo_indptr_list ,
0 commit comments