
Commit fe646da

[FA][autoWS] fix bugs and add configs for maxnreg
Differential Revision: D83878982
Pull Request resolved: #515
1 parent adad9c3

File tree

2 files changed (+7, -3 lines)


tritonbench/kernels/attention_utils.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 WITH_TMA = os.getenv("WITH_TMA")
 HAS_EXPLICIT_WS = os.getenv("ENABLE_EXPLICIT_WS")
 SUPPORT_GLUON = os.getenv("WITH_GLUON")
+WITH_MAXNREG = os.getenv("WITH_MAXNREG")


 class TmaAutoTuneHelper:
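
Note that, like the other switches above, WITH_MAXNREG is read as a raw environment string: os.getenv returns None when the variable is unset and the string value otherwise, so any non-empty value (including "0") makes the flag truthy downstream. A minimal sketch of that behavior; the stricter parse_flag helper is illustrative and not part of this commit:

import os

# As in attention_utils.py: None when unset, otherwise the raw string.
WITH_MAXNREG = os.getenv("WITH_MAXNREG")

# Hypothetical stricter parse, if "WITH_MAXNREG=0" should mean "disabled".
def parse_flag(name: str) -> bool:
    value = os.getenv(name)
    return value is not None and value.lower() not in ("", "0", "false", "off")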

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 6 additions & 3 deletions
@@ -17,6 +17,8 @@
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor

+from .attention_utils import WITH_MAXNREG
+
 from .blackwell_attention_utils import (
     is_blackwell,
     is_cuda,
@@ -204,12 +206,13 @@ def _host_descriptor_pre_hook(nargs):
         num_stages=s,
         num_warps=w,
         pre_hook=_host_descriptor_pre_hook,
+        # ir_override=f"override/_attn_fwd_persist.ttgir"
     )
     for BM in [256]
     for BN in [128]
     for s in NUM_STAGES_OPTIONS
     for w in [4]
-    for subtile in [True, False]
+    for subtile in [False]  # disable subtiling for now
 ]


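Restricting subtile to [False] shrinks the autotune sweep: BM, BN, and num_warps are already singletons, so the config list now has one entry per value in NUM_STAGES_OPTIONS instead of two. A quick sketch of the resulting search-space size, using placeholder values for NUM_STAGES_OPTIONS that are not taken from this diff:

from itertools import product

NUM_STAGES_OPTIONS = [2, 3, 4]  # placeholder; the real list is defined in the kernel file

# Mirrors the comprehension above: BM x BN x num_stages x num_warps x subtile.
configs = list(product([256], [128], NUM_STAGES_OPTIONS, [4], [False]))
print(len(configs))  # 3 configs; with subtile in [True, False] it would be 6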
@@ -267,7 +270,7 @@ def _attn_fwd_tma_dp(
     off_z = off_hz // H
     off_h = off_hz % H

-    offset_y = off_z + off_h * N_CTX
+    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
     qo_offset_y = offset_y + start_m * BLOCK_M
     # initialize offsets
     offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)
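
The offset_y fix treats it as a row offset into Q/K/V flattened over (Z, H, N_CTX): the batch index off_z has to be scaled by H * N_CTX (the number of rows per batch) before the per-head offset is added, whereas the old formula added the bare batch index, which only coincides with the correct value for batch 0. A small worked example with illustrative shapes (the real kernel indexes through TMA descriptors):

# Row offset into a tensor viewed as (Z * H * N_CTX, HEAD_DIM); shapes are illustrative.
Z, H, N_CTX = 4, 8, 1024

def offset_old(off_z, off_h):
    return off_z + off_h * N_CTX                # pre-fix: batch index left unscaled

def offset_new(off_z, off_h):
    return off_z * (N_CTX * H) + off_h * N_CTX  # post-fix: full per-batch stride

assert offset_old(0, 3) == offset_new(0, 3) == 3 * N_CTX  # identical for batch 0
assert offset_new(1, 0) == H * N_CTX                      # 8192: start of batch 1
assert offset_old(1, 0) == 1                              # lands inside batch 0's rows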
@@ -569,7 +572,7 @@ def grid_debug(META):

         ctx.grid = grid
         persistent = baseVariant == "persistent" or baseVariant == "ws_persistent"
-        if is_blackwell() and warp_specialize:
+        if WITH_MAXNREG and is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and (
                 q.dtype == torch.float16 or q.dtype == torch.bfloat16
             ):
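
With this change the maxnreg tuning branch is opt-in: it fires only when WITH_MAXNREG is set in the environment, on top of the existing Blackwell and warp-specialization checks. A minimal usage sketch, assuming the flag is consumed at import time as in attention_utils.py (exporting WITH_MAXNREG=1 in the shell before launching works the same way):

import os

# Set the flag before the kernel module is imported; attention_utils reads it
# with os.getenv at import time, so a later assignment would be too late.
os.environ["WITH_MAXNREG"] = "1"

# Module path taken from the file header above.
from tritonbench.kernels import blackwell_triton_fused_attention_dp  # noqa: E402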
