29
29
)
30
30
31
31
32
+ # Check if Triton version supports minRegAutoWS and maxRegAutoWS
33
+ # These parameters are only available in triton/tree/ws-3.5
34
+ def _supports_reg_auto_ws ():
35
+ """Check if the current Triton version supports minRegAutoWS/maxRegAutoWS"""
36
+ try :
37
+ # Try to create a Config with minRegAutoWS to test support
38
+ test_config = triton .Config ({}, minRegAutoWS = 24 , maxRegAutoWS = 152 )
39
+ return True
40
+ except (TypeError , AttributeError ):
41
+ # Parameter not supported in this Triton version
42
+ return False
43
+
44
+
45
+ HAS_REG_AUTO_WS = _supports_reg_auto_ws ()
46
+
47
+
32
48
@triton .jit
33
49
def _attn_fwd_subtile (
34
50
q ,
@@ -221,20 +237,27 @@ def _host_descriptor_pre_hook(nargs):
221
237
NUM_STAGES_OPTIONS = [3 ]
222
238
223
239
if is_tile_enabled ():
240
+ # Helper to build config with optional minRegAutoWS/maxRegAutoWS
241
+ def make_tile_config (BM , BN , occ , subtile , vectmul , add2reduce ):
242
+ config_kwargs = {
243
+ "BLOCK_M" : BM ,
244
+ "BLOCK_N" : BN ,
245
+ "occupancy" : occ ,
246
+ "SUBTILING" : subtile ,
247
+ "VECT_MUL" : vectmul ,
248
+ "FADD2_REDUCE" : add2reduce ,
249
+ }
250
+ extra_kwargs = {"pre_hook" : _host_descriptor_pre_hook }
251
+
252
+ # Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
253
+ if HAS_REG_AUTO_WS :
254
+ extra_kwargs ["minRegAutoWS" ] = 24
255
+ extra_kwargs ["maxRegAutoWS" ] = 152
256
+
257
+ return triton .Config (config_kwargs , ** extra_kwargs )
258
+
224
259
configs = [
225
- triton .Config (
226
- {
227
- "BLOCK_M" : BM ,
228
- "BLOCK_N" : BN ,
229
- "occupancy" : occ ,
230
- "SUBTILING" : subtile ,
231
- "VECT_MUL" : vectmul ,
232
- "FADD2_REDUCE" : add2reduce ,
233
- },
234
- pre_hook = _host_descriptor_pre_hook ,
235
- minRegAutoWS = 24 ,
236
- maxRegAutoWS = 152 ,
237
- )
260
+ make_tile_config (BM , BN , occ , subtile , vectmul , add2reduce )
238
261
for BM in [64 , 128 , 256 ]
239
262
for BN in [64 , 128 ]
240
263
for occ in [1 , 2 ]
@@ -243,22 +266,30 @@ def _host_descriptor_pre_hook(nargs):
243
266
for add2reduce in [False ]
244
267
]
245
268
else :
269
+ # Helper to build config with optional minRegAutoWS/maxRegAutoWS
270
+ def make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce ):
271
+ config_kwargs = {
272
+ "BLOCK_M" : BM ,
273
+ "BLOCK_N" : BN ,
274
+ "SUBTILING" : subtile ,
275
+ "VECT_MUL" : vectmul ,
276
+ "FADD2_REDUCE" : add2reduce ,
277
+ }
278
+ extra_kwargs = {
279
+ "num_stages" : s ,
280
+ "num_warps" : w ,
281
+ "pre_hook" : _host_descriptor_pre_hook ,
282
+ }
283
+
284
+ # Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
285
+ if HAS_REG_AUTO_WS :
286
+ extra_kwargs ["minRegAutoWS" ] = 24
287
+ extra_kwargs ["maxRegAutoWS" ] = 152
288
+
289
+ return triton .Config (config_kwargs , ** extra_kwargs )
290
+
246
291
configs = [
247
- triton .Config (
248
- {
249
- "BLOCK_M" : BM ,
250
- "BLOCK_N" : BN ,
251
- "SUBTILING" : subtile ,
252
- "VECT_MUL" : vectmul ,
253
- "FADD2_REDUCE" : add2reduce ,
254
- },
255
- num_stages = s ,
256
- num_warps = w ,
257
- pre_hook = _host_descriptor_pre_hook ,
258
- minRegAutoWS = 24 ,
259
- maxRegAutoWS = 152 ,
260
- # ir_override=f"override/_attn_fwd_persist.ttgir"
261
- )
292
+ make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce )
262
293
for BM in [256 ]
263
294
for BN in [128 ]
264
295
for s in NUM_STAGES_OPTIONS
0 commit comments