
Commit 58ca603

Enable TileIR for FA (#498)
1 parent 496c120 commit 58ca603

File tree

3 files changed: +114 -73 lines changed

tritonbench/kernels/blackwell_attention_utils.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+"""
+Set of common attention utils that are exclusive to Blackwell. Separated to avoid issues with more
+generic attention kernels.
+"""
+
+import os
+from functools import lru_cache
+
+import torch
+import triton
+
+
+@lru_cache
+def is_tile_enabled():
+    # Note: This assumes you have the TileIR backend.
+    # We don't have a reliable way to check this at this time.
+    return os.getenv("ENABLE_TILE", "0") == "1"
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+# Note: This seems to only be set at autotuning and cannot be reliably used.
+def is_cuda_tileir():
+    return (
+        triton.runtime.driver.active.get_current_target().backend == "triton_cuda_tile"
+    )
+
+
+def is_cuda_triton():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_cuda():
+    return is_cuda_triton() or is_cuda_tileir()
+
+
+def supports_host_descriptor():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
+def is_blackwell():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
+
+
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
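
For reference, a minimal usage sketch of the helpers added above. This is not part of the commit: the import path is inferred from the relative imports in the two kernel files below, and the print calls are purely illustrative; only ENABLE_TILE and the helper names come from the new file.

# Hedged usage sketch (assumed module path: tritonbench/kernels/blackwell_attention_utils.py).
import os

# Opt in to the TileIR configs. Set this before the first call to
# is_tile_enabled(), since @lru_cache freezes the result afterwards.
os.environ["ENABLE_TILE"] = "1"

from tritonbench.kernels.blackwell_attention_utils import (
    is_blackwell,
    is_cuda,
    is_tile_enabled,
    supports_host_descriptor,
)

print("TileIR configs enabled:", is_tile_enabled())        # True with ENABLE_TILE=1
print("CUDA backend (Triton or TileIR):", is_cuda())
print("Host descriptors supported (sm90+):", supports_host_descriptor())
print("Blackwell (sm100):", is_blackwell())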

tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 33 additions & 36 deletions
@@ -11,35 +11,20 @@
 
 """
 
-import sys
-
-from typing import Optional
-
 import torch
 
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def supports_host_descriptor():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_blackwell():
-    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+from .blackwell_attention_utils import (
+    is_blackwell,
+    is_cuda,
+    is_hip,
+    is_hopper,
+    is_tile_enabled,
+    supports_host_descriptor,
+)
 
 
 @triton.jit
@@ -97,7 +82,7 @@ def _attn_fwd_inner(
         alpha = tl.math.exp2(m_i - m_ij)
         l_ij = tl.sum(p, 1)
         # -- update output accumulator --
-        if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
+        if ((not IS_HOPPER and warp_specialize) and BLOCK_M == 128) and HEAD_DIM == 128:
             BM: tl.constexpr = acc.shape[0]
             BN: tl.constexpr = acc.shape[1]
             acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
@@ -145,18 +130,29 @@ def _host_descriptor_pre_hook(nargs):
 else:
     NUM_STAGES_OPTIONS = [2, 3, 4]
 
-configs = [
-    triton.Config(
-        {"BLOCK_M": BM, "BLOCK_N": BN},
-        num_stages=s,
-        num_warps=w,
-        pre_hook=_host_descriptor_pre_hook,
-    )
-    for BM in [64, 128]
-    for BN in [32, 64, 128]
-    for s in NUM_STAGES_OPTIONS
-    for w in [4, 8]
-]
+if is_tile_enabled():
+    configs = [
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_N": BN, "occupancy": occ},
+            pre_hook=_host_descriptor_pre_hook,
+        )
+        for BM in [64, 128, 256]
+        for BN in [64, 128]
+        for occ in [1, 2]
+    ]
+else:
+    configs = [
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_N": BN},
+            num_stages=s,
+            num_warps=w,
+            pre_hook=_host_descriptor_pre_hook,
+        )
+        for BM in [64, 128]
+        for BN in [32, 64, 128]
+        for s in NUM_STAGES_OPTIONS
+        for w in [4, 8]
+    ]
 
 
 def keep(conf):
@@ -453,6 +449,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
+        warp_specialize = baseVariant == "ws" or baseVariant == "ws_persistent"
         # Use device_descriptor for Hopper + warpspec.
         if supports_host_descriptor() and not (is_hopper() and warp_specialize):
             # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 33 additions & 37 deletions
@@ -11,35 +11,20 @@
 
 """
 
-import sys
-
-from typing import Optional
-
 import torch
 
 import triton
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def supports_host_descriptor():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_blackwell():
-    return is_cuda() and torch.cuda.get_device_capability()[0] == 10
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] == 9
+from .blackwell_attention_utils import (
+    is_blackwell,
+    is_cuda,
+    is_hip,
+    is_hopper,
+    is_tile_enabled,
+    supports_host_descriptor,
+)
 
 
 @triton.jit
@@ -201,20 +186,31 @@ def _host_descriptor_pre_hook(nargs):
 else:
     NUM_STAGES_OPTIONS = [3]
 
-configs = [
-    triton.Config(
-        {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
-        num_stages=s,
-        num_warps=w,
-        pre_hook=_host_descriptor_pre_hook,
-        # ir_override=f"/home/mren/OpenSource/tritonbench/override/_attn_fwd_persist.ttgir"
-    )
-    for BM in [256]
-    for BN in [128]
-    for s in NUM_STAGES_OPTIONS
-    for w in [4]
-    for subtile in [True]
-]
+if is_tile_enabled():
+    configs = [
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_N": BN, "occupancy": occ, "SUBTILING": subtile},
+            pre_hook=_host_descriptor_pre_hook,
+        )
+        for BM in [64, 128, 256]
+        for BN in [64, 128]
+        for occ in [1, 2]
+        for subtile in [True]
+    ]
+else:
+    configs = [
+        triton.Config(
+            {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
+            num_stages=s,
+            num_warps=w,
+            pre_hook=_host_descriptor_pre_hook,
+        )
+        for BM in [256]
+        for BN in [128]
+        for s in NUM_STAGES_OPTIONS
+        for w in [4]
+        for subtile in [True, False]
+    ]
 
 
 def keep(conf):
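
A note on enabling the new path (a hedged sketch, not part of the commit): both kernel files build their configs list at module import time, and is_tile_enabled() caches its result via @lru_cache, so ENABLE_TILE has to be set in the environment before the kernel modules are first imported. The module path below is inferred from the file names in this commit; the shell command is illustrative only.

# From the shell (illustrative command, not a documented tritonbench invocation):
#   ENABLE_TILE=1 python your_benchmark.py
#
# Or programmatically, before the first import of the kernel module:
import os

os.environ["ENABLE_TILE"] = "1"  # must precede the import below

# Importing afterwards makes the module-level `if is_tile_enabled():` branch
# select the occupancy-based TileIR configs instead of num_stages/num_warps.
from tritonbench.kernels import blackwell_triton_fused_attention  # noqa: F401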
