File tree Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Expand file tree Collapse file tree 2 files changed +7
-3
lines changed Original file line number Diff line number Diff line change 17
17
WITH_TMA = os .getenv ("WITH_TMA" )
18
18
HAS_EXPLICIT_WS = os .getenv ("ENABLE_EXPLICIT_WS" )
19
19
SUPPORT_GLUON = os .getenv ("WITH_GLUON" )
20
+ WITH_MAXNREG = os .getenv ("WITH_MAXNREG" )
20
21
21
22
22
23
class TmaAutoTuneHelper :
Original file line number Diff line number Diff line change 17
17
import triton .language as tl
18
18
from triton .tools .tensor_descriptor import TensorDescriptor
19
19
20
+ from .attention_utils import WITH_MAXNREG
21
+
20
22
from .blackwell_attention_utils import (
21
23
is_blackwell ,
22
24
is_cuda ,
@@ -204,12 +206,13 @@ def _host_descriptor_pre_hook(nargs):
204
206
num_stages = s ,
205
207
num_warps = w ,
206
208
pre_hook = _host_descriptor_pre_hook ,
209
+ # ir_override=f"override/_attn_fwd_persist.ttgir"
207
210
)
208
211
for BM in [256 ]
209
212
for BN in [128 ]
210
213
for s in NUM_STAGES_OPTIONS
211
214
for w in [4 ]
212
- for subtile in [True , False ]
215
+ for subtile in [False ] # disable subtiling for now
213
216
]
214
217
215
218
@@ -267,7 +270,7 @@ def _attn_fwd_tma_dp(
267
270
off_z = off_hz // H
268
271
off_h = off_hz % H
269
272
270
- offset_y = off_z + off_h * N_CTX
273
+ offset_y = off_z * ( N_CTX * H ) + off_h * N_CTX
271
274
qo_offset_y = offset_y + start_m * BLOCK_M
272
275
# initialize offsets
273
276
offs_m0 = start_m * BLOCK_M + tl .arange (0 , BLOCK_M // 2 )
@@ -569,7 +572,7 @@ def grid_debug(META):
569
572
570
573
ctx .grid = grid
571
574
persistent = baseVariant == "persistent" or baseVariant == "ws_persistent"
572
- if is_blackwell () and warp_specialize :
575
+ if WITH_MAXNREG and is_blackwell () and warp_specialize :
573
576
if HEAD_DIM_K == 128 and (
574
577
q .dtype == torch .float16 or q .dtype == torch .bfloat16
575
578
):
You can’t perform that action at this time.
0 commit comments