Skip to content

Commit 2beb393

Browse files
authored
[autoWS] CI fix - check if feature is supported for autotuning on minRegAutoWS and maxRegAutoWS (#535)
1 parent ee97170 commit 2beb393

File tree

1 file changed

+59
-28
lines changed

1 file changed

+59
-28
lines changed

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@
2929
)
3030

3131

32+
# Check if Triton version supports minRegAutoWS and maxRegAutoWS
33+
# These parameters are only available in triton/tree/ws-3.5
34+
def _supports_reg_auto_ws():
    """Report whether ``triton.Config`` accepts minRegAutoWS/maxRegAutoWS.

    The check is a probe: a throwaway ``triton.Config`` is constructed with
    both keyword arguments and the outcome of that construction is the answer.
    Per the note accompanying this helper, these parameters are only
    available in triton/tree/ws-3.5.

    Returns:
        bool: True if the probe construction succeeds, False if it raises
        ``TypeError`` or ``AttributeError``.
    """
    try:
        # The Config object is discarded; only whether construction raises
        # matters, so nothing is bound to a local name.
        triton.Config({}, minRegAutoWS=24, maxRegAutoWS=152)
    except (TypeError, AttributeError):
        # The installed Triton rejected the keyword arguments.
        return False
    return True
43+
44+
45+
# Probe result computed once here; the config builders further down read this
# flag to decide whether to pass minRegAutoWS/maxRegAutoWS to triton.Config.
HAS_REG_AUTO_WS = _supports_reg_auto_ws()
46+
47+
3248
@triton.jit
3349
def _attn_fwd_subtile(
3450
q,
@@ -221,20 +237,27 @@ def _host_descriptor_pre_hook(nargs):
221237
NUM_STAGES_OPTIONS = [3]
222238

223239
if is_tile_enabled():
240+
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
241+
def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
    """Build one autotune ``triton.Config`` for the tile-enabled path.

    Args:
        BM: Value stored under the "BLOCK_M" key.
        BN: Value stored under the "BLOCK_N" key.
        occ: Value stored under the "occupancy" key.
        subtile: Value stored under the "SUBTILING" key.
        vectmul: Value stored under the "VECT_MUL" key.
        add2reduce: Value stored under the "FADD2_REDUCE" key.

    Returns:
        A ``triton.Config`` carrying the keys above, with
        ``_host_descriptor_pre_hook`` as its pre-hook. The
        minRegAutoWS/maxRegAutoWS arguments are passed only when
        ``HAS_REG_AUTO_WS`` is true.
    """
    # Register bounds are forwarded only when the installed Triton accepts
    # them (triton/tree/ws-3.5); otherwise they are left out entirely.
    reg_bounds = (
        {"minRegAutoWS": 24, "maxRegAutoWS": 152} if HAS_REG_AUTO_WS else {}
    )
    meta = dict(
        BLOCK_M=BM,
        BLOCK_N=BN,
        occupancy=occ,
        SUBTILING=subtile,
        VECT_MUL=vectmul,
        FADD2_REDUCE=add2reduce,
    )
    return triton.Config(meta, pre_hook=_host_descriptor_pre_hook, **reg_bounds)
258+
224259
configs = [
225-
triton.Config(
226-
{
227-
"BLOCK_M": BM,
228-
"BLOCK_N": BN,
229-
"occupancy": occ,
230-
"SUBTILING": subtile,
231-
"VECT_MUL": vectmul,
232-
"FADD2_REDUCE": add2reduce,
233-
},
234-
pre_hook=_host_descriptor_pre_hook,
235-
minRegAutoWS=24,
236-
maxRegAutoWS=152,
237-
)
260+
make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce)
238261
for BM in [64, 128, 256]
239262
for BN in [64, 128]
240263
for occ in [1, 2]
@@ -243,22 +266,30 @@ def _host_descriptor_pre_hook(nargs):
243266
for add2reduce in [False]
244267
]
245268
else:
269+
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
270+
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
    """Build one autotune ``triton.Config`` for the non-tile path.

    Args:
        BM: Value stored under the "BLOCK_M" key.
        BN: Value stored under the "BLOCK_N" key.
        s: Passed to ``triton.Config`` as ``num_stages``.
        w: Passed to ``triton.Config`` as ``num_warps``.
        subtile: Value stored under the "SUBTILING" key.
        vectmul: Value stored under the "VECT_MUL" key.
        add2reduce: Value stored under the "FADD2_REDUCE" key.

    Returns:
        A ``triton.Config`` carrying the keys above, with
        ``_host_descriptor_pre_hook`` as its pre-hook. The
        minRegAutoWS/maxRegAutoWS arguments are passed only when
        ``HAS_REG_AUTO_WS`` is true.
    """
    # Register bounds are forwarded only when the installed Triton accepts
    # them (triton/tree/ws-3.5); otherwise they are left out entirely.
    reg_bounds = (
        {"minRegAutoWS": 24, "maxRegAutoWS": 152} if HAS_REG_AUTO_WS else {}
    )
    meta = dict(
        BLOCK_M=BM,
        BLOCK_N=BN,
        SUBTILING=subtile,
        VECT_MUL=vectmul,
        FADD2_REDUCE=add2reduce,
    )
    return triton.Config(
        meta,
        num_stages=s,
        num_warps=w,
        pre_hook=_host_descriptor_pre_hook,
        **reg_bounds,
    )
290+
246291
configs = [
247-
triton.Config(
248-
{
249-
"BLOCK_M": BM,
250-
"BLOCK_N": BN,
251-
"SUBTILING": subtile,
252-
"VECT_MUL": vectmul,
253-
"FADD2_REDUCE": add2reduce,
254-
},
255-
num_stages=s,
256-
num_warps=w,
257-
pre_hook=_host_descriptor_pre_hook,
258-
minRegAutoWS=24,
259-
maxRegAutoWS=152,
260-
# ir_override=f"override/_attn_fwd_persist.ttgir"
261-
)
292+
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce)
262293
for BM in [256]
263294
for BN in [128]
264295
for s in NUM_STAGES_OPTIONS

0 commit comments

Comments
 (0)