Commit 0213e77

Add Support for TLX matmul (#340)
Summary: Adds a TLX matmul to TritonBench and guards it behind TLX support. Differential Revision: D79178034
1 parent 7cea6a6 commit 0213e77

File tree (4 files changed, +320, -1 lines)

tritonbench/operators/gemm/operator.py
tritonbench/operators/gemm/stream_k.py
tritonbench/operators/gemm/tlx_matmul.py
tritonbench/utils/triton_utils.py
tritonbench/operators/gemm/operator.py

Lines changed: 17 additions & 0 deletions
@@ -16,6 +16,16 @@
     blackwell_matmul_tma,
     blackwell_matmul_tma_persistent,
 )
+from tritonbench.utils.triton_utils import has_tlx
+
+if has_tlx():
+    from tritonbench.operators.gemm.tlx_matmul import tlx_matmul as _tlx_matmul
+else:
+
+    def _tlx_matmul(*args, **kwargs):
+        raise RuntimeError("TLX not available in this Triton version")
+
+
 from tritonbench.utils.data_utils import get_production_shapes
 from tritonbench.utils.env_utils import (
     get_nvidia_gpu_model,
@@ -445,6 +455,13 @@ def triton_blackwell_descriptor_matmul(self, a, b, bias) -> Callable:
             a, b, warp_specialize=False
         )
 
+    @register_benchmark(enabled=False)
+    def tlx_matmul(self, a, b, bias) -> Callable:
+        if bias is not None:
+            return lambda: _tlx_matmul(a, b) + bias
+        else:
+            return lambda: _tlx_matmul(a, b)
+
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         # x-value: computation intensity
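
For reference, the new benchmark method hands the harness a zero-argument closure over the inputs, and it is registered with `enabled=False`, so it only runs when explicitly selected. A minimal sketch of exercising it directly; the names `op`, `a`, and `b` are assumed placeholders (an instantiated gemm Operator and fp16 CUDA tensors), not part of this diff:

# Hypothetical usage sketch, not part of the commit.
fn = op.tlx_matmul(a, b, bias=None)  # returns a zero-argument lambda
c = fn()  # runs _tlx_matmul(a, b); raises RuntimeError if TLX is unavailable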

tritonbench/operators/gemm/stream_k.py

Lines changed: 1 addition & 1 deletion
@@ -646,6 +646,6 @@ def grid(META):
         K,  #
         FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         ENABLE_BUFFER_OPS_ASSUMES=True,  #
-        NUM_SMS=num_sms  #
+        NUM_SMS=num_sms,  #
     )
     return c
tritonbench/operators/gemm/tlx_matmul.py

Lines changed: 288 additions & 0 deletions (new file)
# TLX GEMM kernel optimized for Blackwell Warp Specialization
import torch

import triton
import triton.language as tl
import triton.language.extra.tlx as tlx
from triton.tools.tensor_descriptor import TensorDescriptor


def get_cuda_autotune_config():
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": BM,
                "BLOCK_SIZE_N": BN,
                "BLOCK_SIZE_K": BK,
                "GROUP_SIZE_M": 8,
                "NUM_SMEM_BUFFERS": s,
                "NUM_TMEM_BUFFERS": t,
                "EPILOGUE_SUBTILE": subtile,
            },
            num_warps=4,
            num_stages=1,
            pre_hook=matmul_tma_set_block_size_hook,
        )
        for BM in [128]
        for BN in [128, 256]
        for BK in [64, 128]
        for s in [2, 3, 4]
        for t in [2, 3]
        for subtile in [True]
    ]


def matmul_tma_set_block_size_hook(nargs):
    BLOCK_M = nargs["BLOCK_SIZE_M"]
    BLOCK_N = nargs["BLOCK_SIZE_N"]
    BLOCK_K = nargs["BLOCK_SIZE_K"]
    nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
    nargs["b_desc"].block_shape = [BLOCK_K, BLOCK_N]
    EPILOGUE_SUBTILE = nargs.get("EPILOGUE_SUBTILE", False)
    if EPILOGUE_SUBTILE:
        nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N // 2]
    else:
        nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel_tma_ws_blackwell(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMEM_BUFFERS: tl.constexpr,  #
    NUM_TMEM_BUFFERS: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    EPILOGUE_SUBTILE: tl.constexpr,  #
):
    # allocate NUM_SMEM_BUFFERS buffers
    buffers_A = tlx.local_alloc(
        (BLOCK_SIZE_M, BLOCK_SIZE_K), tl.float16, NUM_SMEM_BUFFERS
    )
    buffers_B = tlx.local_alloc(
        (BLOCK_SIZE_K, BLOCK_SIZE_N), tl.float16, NUM_SMEM_BUFFERS
    )
    # use multiple TMEM buffers to overlap MMA and epilogue
    tmem_buffers = tlx.local_alloc(
        (BLOCK_SIZE_M, BLOCK_SIZE_N),
        tl.float32,
        NUM_TMEM_BUFFERS,
        tlx.storage_kind.tmem,
    )

    # allocate barriers
    smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1)
    smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1)
    tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)
    tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1)

    with tlx.async_tasks():
        with tlx.async_task("default"):  # producer, TMA load
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            num_tiles = num_pid_m * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code

            load_phase = 0  # the current phase of TMA load
            # we virtually "flatten" the two layer loop as if we're performing tma loads on
            # one big list of data
            processed_k_iters = 0
            for tile_id in range(start_pid, num_tiles, NUM_SMS):
                pid_m, pid_n = _compute_pid(
                    tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M
                )
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N

                for k in range(0, k_tiles):
                    # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1
                    buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS
                    # wait for previous phase(round) of dot for this buf
                    tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1)
                    # buffer is now ready to be used again
                    offs_k = k * BLOCK_SIZE_K
                    tlx.barrier_expect_bytes(
                        smem_full_bars[buf],
                        2 * (BLOCK_SIZE_M + BLOCK_SIZE_N) * BLOCK_SIZE_K,
                    )  # float16
                    tlx.async_descriptor_load(
                        a_desc, buffers_A[buf], [offs_am, offs_k], smem_full_bars[buf]
                    )
                    tlx.async_descriptor_load(
                        b_desc, buffers_B[buf], [offs_k, offs_bn], smem_full_bars[buf]
                    )
                    # flip phase at the end of a round
                    load_phase = load_phase ^ (buf == NUM_SMEM_BUFFERS - 1)
                processed_k_iters += k_tiles
        with tlx.async_task(num_warps=1, num_regs=232):  # MMA consumer
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            num_tiles = num_pid_m * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code

            dot_phase = 0  # the current phase of dot op
            tmem_write_phase = 1  # sync between epilogue consumer and MMA consumer
            cur_tmem_buf = 0

            processed_k_iters = 0
            for tile_id in range(start_pid, num_tiles, NUM_SMS):
                pid_m, pid_n = _compute_pid(
                    tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M
                )
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N

                # wait epilogue consumer to be done with the buffer before reusing it
                tlx.barrier_wait(tmem_empty_bars[cur_tmem_buf], tmem_write_phase)
                # flip phase at the end of a round of using TMEM barriers
                tmem_write_phase = tmem_write_phase ^ (
                    cur_tmem_buf == NUM_TMEM_BUFFERS - 1
                )

                # now iterate along K to compute result for the block
                for k in range(0, k_tiles):
                    # processed_k_iters + k means we use the immediate next buffer slot of tile_id x when we start tile_id x+1
                    buf = (processed_k_iters + k) % NUM_SMEM_BUFFERS
                    # wait for current phase(round) of load for this buf
                    tlx.barrier_wait(smem_full_bars[buf], dot_phase)
                    # buffer is now ready with loaded data, tlx.async_dot will signal `mBarrier` when done
                    tlx.async_dot(
                        buffers_A[buf],
                        buffers_B[buf],
                        tmem_buffers[cur_tmem_buf],
                        use_acc=k > 0,
                        mBarriers=[smem_empty_bars[buf]],
                        out_dtype=tl.float32,
                    )
                    # flip phase at the end of a round
                    dot_phase = dot_phase ^ (buf == NUM_SMEM_BUFFERS - 1)

                # wait for last mma to complete
                last_buf = (processed_k_iters + k_tiles - 1) % NUM_SMEM_BUFFERS
                # in case phase was flipped, we should use the phase value when dot op was issued
                last_dot_phase = dot_phase ^ (last_buf == NUM_SMEM_BUFFERS - 1)
                tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase)

                # done filling this buffer, signal epilogue consumer
                tlx.barrier_arrive(tmem_full_bars[cur_tmem_buf], 1)

                # possibly enter next iteration (next tile) without waiting for epilogue
                cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS
                processed_k_iters += k_tiles

        with tlx.async_task(num_warps=4, num_regs=232):  # epilogue consumer
            # common code duplicated for each region to avoid SMEM overhead
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            num_tiles = num_pid_m * num_pid_n
            k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            # end of common code

            tmem_read_phase = 0
            cur_tmem_buf = 0

            for tile_id in range(start_pid, num_tiles, NUM_SMS):
                pid_m, pid_n = _compute_pid(
                    tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M
                )
                offs_am = pid_m * BLOCK_SIZE_M
                offs_bn = pid_n * BLOCK_SIZE_N

                tlx.barrier_wait(tmem_full_bars[cur_tmem_buf], tmem_read_phase)
                # flip phase at the end of a round of using TMEM barriers
                tmem_read_phase = tmem_read_phase ^ (
                    cur_tmem_buf == NUM_TMEM_BUFFERS - 1
                )

                # load the result from TMEM to registers
                acc_tmem = tmem_buffers[cur_tmem_buf]

                if EPILOGUE_SUBTILE:
                    # We load/store the result half by half to reduce SMEM pressure
                    acc_tmem_subslice1 = tlx.subslice(acc_tmem, 0, BLOCK_SIZE_N // 2)
                    result = tlx.local_load(acc_tmem_subslice1)
                    c = result.to(tl.float16)
                    c_desc.store([offs_am, offs_bn], c)

                    acc_tmem_subslice2 = tlx.subslice(
                        acc_tmem, BLOCK_SIZE_N // 2, BLOCK_SIZE_N // 2
                    )
                    result = tlx.local_load(acc_tmem_subslice2)
                    c = result.to(tl.float16)
                    c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], c)
                else:
                    result = tlx.local_load(acc_tmem)
                    c = result.to(tl.float16)
                    c_desc.store([offs_am, offs_bn], c)

                # done storing this buffer, signal MMA consumer to resume writing to it
                tlx.barrier_arrive(tmem_empty_bars[cur_tmem_buf], 1)

                cur_tmem_buf = (cur_tmem_buf + 1) % NUM_TMEM_BUFFERS


def tlx_matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)

    # A dummy block value that will be overwritten when we have the real block size
    dummy_block = [1, 1]
    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
    c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Persistent kernel to have thread block resident in SM as long as possible
    grid = lambda META: (
        min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )
    matmul_kernel_tma_ws_blackwell[grid](
        a_desc,
        b_desc,
        c_desc,  #
        M,
        N,
        K,  #
        NUM_SMS=NUM_SMS,  #
    )
    return c
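
A minimal smoke test for the new kernel, assuming a Blackwell GPU and a TLX-enabled Triton build; the sizes are illustrative, and the dtype and contiguity requirements follow the asserts in tlx_matmul:

import torch
from tritonbench.operators.gemm.tlx_matmul import tlx_matmul

# Illustrative problem size; any shapes compatible with the autotuned block sizes work.
M, N, K = 1024, 1024, 1024
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)

c = tlx_matmul(a, b)
ref = torch.matmul(a, b)
# Loose tolerances: the kernel accumulates in fp32 but stores fp16.
torch.testing.assert_close(c, ref, atol=1e-2, rtol=1e-2)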

tritonbench/utils/triton_utils.py

Lines changed: 14 additions & 0 deletions
@@ -1,6 +1,8 @@
 # utils to identify triton versions
 
 import triton.language as tl
+import functools
+import importlib.util
 
 
 class AsyncTaskContext:
@@ -34,3 +36,15 @@ def has_new_tma():
     import triton.language as tl
 
     return hasattr(triton, "set_allocator") and hasattr(tl, "make_tensor_descriptor")
+
+
+@functools.lru_cache
+def has_tlx():
+    """
+    Returns whether TLX is supported.
+    """
+    # TODO: Replace with the variant in compat once that's
+    # available in OSS.
+    tlx_module = "triton.language.extra.tlx"
+    spec = importlib.util.find_spec(tlx_module)
+    return spec is not None
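
A quick interactive check of the new helper, assuming tritonbench is importable in the current environment:

from tritonbench.utils.triton_utils import has_tlx

# True only when the `triton.language.extra.tlx` module can be found,
# i.e. the installed Triton build ships TLX.
print(has_tlx())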
