
Commit aac457e

[TUTORIAL] Replace legacy host side TMA with TensorDescriptor (#6465)
1 parent 0fd62f6 commit aac457e

File tree: 3 files changed (+56, -153 lines)


python/triton/tools/experimental_descriptor.py

Lines changed: 8 additions & 0 deletions
@@ -46,3 +46,11 @@ class TensorDescriptor:
     shape: List[int]
     strides: List[int]
     block_shape: List[int]
+
+    def from_tensor(tensor: Any, block_shape: List[int]):
+        return TensorDescriptor(
+            tensor,
+            tensor.shape,
+            tensor.stride(),
+            block_shape,
+        )
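A minimal host-side usage sketch of the new helper (illustrative only, not part of this diff); it assumes a contiguous CUDA tensor and a Triton build that ships this TensorDescriptor class:

import torch
from triton.tools.experimental_descriptor import TensorDescriptor

x = torch.randn(1024, 128, device="cuda", dtype=torch.float16)

# from_tensor reads shape and strides from the tensor itself; only the
# per-copy block shape has to be spelled out.
desc = TensorDescriptor.from_tensor(x, [64, 128])

# The explicit constructor remains useful when the logical 2D view differs
# from the tensor's own metadata (e.g. flattened batch/head dimensions).
flat_desc = TensorDescriptor(x, shape=[1024, 128], strides=[128, 1], block_shape=[64, 128])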

python/tutorials/06-fused-attention.py

Lines changed: 21 additions & 92 deletions
@@ -15,7 +15,7 @@
 
 import pytest
 import torch
-import triton.tools.experimental_descriptor
+from triton.tools.experimental_descriptor import TensorDescriptor
 
 import triton
 import triton.language as tl
@@ -43,68 +43,6 @@ def supports_tma():
     print("TMA benchmarks will be running without grid constant TMA descriptor.", )
 
 
-# TmaAutoTuneHelper used in htyu's PR #5622
-class TmaAutoTuneHelper:
-
-    # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
-    class KernelParamWrapper:
-
-        def __init__(self, desc):
-            self.desc = desc
-
-        def tma_desc_cpu_ptr(self):
-            return self.desc.data_ptr()
-
-    TMA_SIZE = 128
-
-    def __init__(self):
-        self.fill_1d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_1d_tma_descriptor)
-        self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor)
-        if HAS_TMA_DESC:
-            self.descriptors = {}
-        else:
-            self.cuda_descriptors = {}
-
-    # Call this method outside of the lambda function for grid size
-    def init_tma_descriptor(self, name):
-        if HAS_TMA_DESC:
-            self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
-        else:
-            self.cuda_descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
-
-    # Call this method inside the lambda function for grid size
-    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
-
-    # Call this method inside the lambda function for grid size
-    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
-        if HAS_TMA_DESC:
-            desc_x = self.descriptors[name]
-            assert desc_x.data_ptr() % 64 == 0
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
-        else:
-            desc_x = self.cuda_descriptors[name]
-            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
-            desc_x.copy_(buf_x, non_blocking=True)
-
-    def get_tma_descriptor_kernel_param(self, name):
-        if HAS_TMA_DESC:
-            assert self.descriptors[name] is not None
-            return self.KernelParamWrapper(self.descriptors[name])
-        else:
-            assert self.cuda_descriptors[name] is not None
-            return self.cuda_descriptors[name]
-
-
 @triton.jit
 def _attn_fwd_inner(acc, l_i, m_i, q, #
                     K_block_ptr, V_block_ptr, #
@@ -179,7 +117,7 @@ def _attn_fwd_inner_tma(acc, l_i, m_i, q, #
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        k = tl._experimental_descriptor_load(desc_k, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype).T
+        k = desc_k.load([offsetkv_y, 0]).T
         qk = tl.dot(q, k)
         if STAGE == 2:
             mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -197,7 +135,7 @@ def _attn_fwd_inner_tma(acc, l_i, m_i, q, #
         # -- update output accumulator --
         acc = acc * alpha[:, None]
         # update acc
-        v = tl._experimental_descriptor_load(desc_v, [offsetkv_y, 0], [BLOCK_N, HEAD_DIM], dtype)
+        v = desc_v.load([offsetkv_y, 0])
         p = p.to(dtype)
         # note that this non transposed v for FP8 is only supported on Blackwell
         acc = tl.dot(p, v, acc)
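The kernel-side pattern introduced here (the tile shape comes from the descriptor's block_shape instead of being repeated at each load/store call) is sketched below in isolation. This is an illustrative snippet, not part of the commit; it assumes a TMA-capable GPU (Hopper or newer), a contiguous fp16 input whose inner dimension meets TMA alignment constraints (e.g. 128 columns), and the descriptor .load()/.store() methods used throughout this diff.

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import TensorDescriptor


@triton.jit
def _copy_tile_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    off_m = pid * BLOCK_M
    tile = in_desc.load([off_m, 0])    # tile shape is the descriptor's block_shape
    out_desc.store([off_m, 0], tile)


def copy_2d(x: torch.Tensor, BLOCK_M: int = 64) -> torch.Tensor:
    out = torch.empty_like(x)
    in_desc = TensorDescriptor.from_tensor(x, [BLOCK_M, x.shape[1]])
    out_desc = TensorDescriptor.from_tensor(out, [BLOCK_M, x.shape[1]])
    grid = (triton.cdiv(x.shape[0], BLOCK_M), )
    _copy_tile_kernel[grid](in_desc, out_desc, BLOCK_M)
    return out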
@@ -319,11 +257,21 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     tl.store(O_block_ptr, acc.to(Out.type.element_ty))
 
 
+def _tma_pre_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    HEAD_DIM = nargs["HEAD_DIM"]
+    nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
+    nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
+    nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
+    nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
+
+
 # We don't run auto-tuning every time to keep the tutorial fast. Keeping
 # the code below and commenting out the equivalent parameters is convenient for
 # re-tuning.
 configs_tma = [
-    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
+    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_tma_pre_hook) \
    for BM in [64, 128]\
    for BN in [32, 64, 128]\
    for s in [2, 3, 4, 6]\
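Note on this hunk: as the forward() hunk further below shows, the descriptors are created before the autotuner has chosen BLOCK_M/BLOCK_N, so they start with a placeholder block_shape of [1, 1]; the pre_hook attached to each config then rewrites block_shape to that config's tile sizes right before launch.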
@@ -369,7 +317,7 @@ def _attn_fwd_tma(sm_scale, M, #
     qk_scale = sm_scale
     qk_scale *= 1.44269504  # 1/log(2)
     # load q: it will stay in SRAM throughout
-    q = tl._experimental_descriptor_load(desc_q, [qo_offset_y, 0], [BLOCK_M, HEAD_DIM], dtype)
+    q = desc_q.load([qo_offset_y, 0])
     # stage 1: off-band
     # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
     # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
@@ -395,7 +343,7 @@ def _attn_fwd_tma(sm_scale, M, #
     acc = acc / l_i[:, None]
     m_ptrs = M + off_hz * N_CTX + offs_m
     tl.store(m_ptrs, m_i)
-    tl._experimental_descriptor_store(desc_o, acc.to(dtype), [qo_offset_y, 0])
+    desc_o.store([qo_offset_y, 0], acc.to(dtype))
 
 
 @triton.jit
@@ -670,34 +618,15 @@ def forward(ctx, q, k, v, causal, sm_scale, USE_TMA=True):
         # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
         y_dim = q.shape[0] * q.shape[1] * q.shape[2]
 
-        desc_helper = TmaAutoTuneHelper()
-        desc_helper.init_tma_descriptor("q")
-        desc_helper.init_tma_descriptor("v")
-        desc_helper.init_tma_descriptor("k")
-        desc_helper.init_tma_descriptor("o")
+        dummy_block = [1, 1]
+        desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+        desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+        desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
+        desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
 
         def grid(META):
-            nonlocal desc_helper
-
-            desc_helper.fill_2d_tma_descriptor("q", q.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
-                                               q.element_size())
-
-            desc_helper.fill_2d_tma_descriptor("v", v.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
-                                               v.element_size())
-
-            desc_helper.fill_2d_tma_descriptor("k", k.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_N"], HEAD_DIM_K,
-                                               k.element_size())
-
-            desc_helper.fill_2d_tma_descriptor("o", o.data_ptr(), y_dim, HEAD_DIM_K, META["BLOCK_M"], HEAD_DIM_K,
-                                               o.element_size())
-
             return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
 
-        desc_q = desc_helper.get_tma_descriptor_kernel_param("q")
-        desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
-        desc_k = desc_helper.get_tma_descriptor_kernel_param("k")
-        desc_o = desc_helper.get_tma_descriptor_kernel_param("o")
-
         ctx.grid = grid
         _attn_fwd_tma[grid](
             sm_scale, M, #
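Taken together, the forward() changes drop the TmaAutoTuneHelper byte-buffer plumbing: the q, k, v and o descriptors are now plain TensorDescriptor objects built directly from each tensor's shape and strides and passed to _attn_fwd_tma as ordinary kernel arguments.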

python/tutorials/10-block-scaled-matmul.py

Lines changed: 27 additions & 61 deletions
@@ -72,7 +72,7 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 import triton.profiler as proton
-from triton.tools.experimental_descriptor import TmaDescKernelParam
+from triton.tools.experimental_descriptor import TensorDescriptor
 from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
 
 
@@ -106,7 +106,7 @@ def _matmul_launch_metadata(grid, kernel, args):
 @triton.jit(launch_metadata=_matmul_launch_metadata)
 def block_scaled_matmul_kernel( #
         a_desc, a_scale, #
-        b_desc_or_tensor, b_scale, #
+        b_desc, b_scale, #
         c_desc, #
         M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, #
         stride_sk: tl.constexpr, stride_sb: tl.constexpr, stride_sc: tl.constexpr, stride_sd: tl.constexpr,
@@ -120,16 +120,6 @@ def block_scaled_matmul_kernel( #
         NUM_STAGES: tl.constexpr, #
         USE_2D_SCALE_LOAD: tl.constexpr): #
 
-    if ELEM_PER_BYTE_A == 1:
-        dtype_a = tl.float8e4nv
-    elif ELEM_PER_BYTE_A == 2:
-        dtype_a = tl.dtype("uint8")
-
-    if ELEM_PER_BYTE_B == 1:
-        dtype_b = tl.float8e4nv
-    elif ELEM_PER_BYTE_B == 2:
-        dtype_b = tl.dtype("uint8")
-
     if output_type == 0:
         output_dtype = tl.float32
     elif output_type == 1:
@@ -152,23 +142,6 @@ def block_scaled_matmul_kernel( #
 
     MIXED_PREC: tl.constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2
 
-    if MIXED_PREC:
-        b_desc = tl.make_tensor_descriptor(
-            b_desc_or_tensor,
-            shape=[N, K // ELEM_PER_BYTE_B],
-            strides=[K // ELEM_PER_BYTE_B, 1],
-            block_shape=[BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
-        )
-    else:
-        b_desc = b_desc_or_tensor
-        tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [b_desc], dtype=tl.int32,
-                                  is_pure=False, pack=1)
-
-    tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [a_desc], dtype=tl.int32, is_pure=False,
-                              pack=1)
-    tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [c_desc], dtype=tl.int32, is_pure=False,
-                              pack=1)
-
     # For now it is recommended to use 2D scale loads for better performance.
     # In the future we will bring additional optimizations to either allow 5D loads,
     # the use of TMAs for scale factors, or both.
@@ -192,15 +165,8 @@ def block_scaled_matmul_kernel( #
 
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
-        a = tl._experimental_descriptor_load(a_desc, [offs_am, offs_k_a], [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A],
-                                             dtype_a)
-
-        if MIXED_PREC:
-            b = b_desc.load([offs_bn, offs_k_b])
-        else:
-            b = tl._experimental_descriptor_load(b_desc, [offs_bn, offs_k_b], [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
-                                                 dtype_b)
-
+        a = a_desc.load([offs_am, offs_k_a])
+        b = b_desc.load([offs_bn, offs_k_b])
         scale_a = tl.load(a_scale_ptr)
         scale_b = tl.load(b_scale_ptr)
         if USE_2D_SCALE_LOAD:
@@ -221,10 +187,10 @@ def block_scaled_matmul_kernel( #
         a_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
         b_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
 
-    tl._experimental_descriptor_store(c_desc, accumulator.to(output_dtype), [offs_am, offs_bn])
+    c_desc.store([offs_am, offs_bn], accumulator.to(output_dtype))
 
 
-def block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, dtype_dst, M, N, K, configs):
+def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, configs):
     output = torch.empty((M, N), dtype=dtype_dst, device="cuda")
     if dtype_dst == torch.float32:
         dtype_dst = 0
@@ -235,11 +201,12 @@ def block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, dtype_dst, M
     else:
         raise ValueError(f"Unsupported dtype: {dtype_dst}")
 
-    c_desc = TmaDescKernelParam(output.data_ptr(), output.shape, [configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"]],
-                                output.element_size())
+    BLOCK_M = configs["BLOCK_SIZE_M"]
+    BLOCK_N = configs["BLOCK_SIZE_N"]
+    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])
 
-    grid = (triton.cdiv(M, configs["BLOCK_SIZE_M"]) * triton.cdiv(N, configs["BLOCK_SIZE_N"]), 1)
-    block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc_or_tensor, b_scale, c_desc, M, N, K, a_scale.stride(0),
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc, b_scale, c_desc, M, N, K, a_scale.stride(0),
                                      a_scale.stride(1), a_scale.stride(2), a_scale.stride(3), dtype_dst,
                                      configs["ELEM_PER_BYTE_A"], configs["ELEM_PER_BYTE_B"], configs["VEC_SIZE"],
                                      configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"],
@@ -284,12 +251,17 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
 
     b_ref = b_ref.to(torch.float32).T
 
-    a_desc = TmaDescKernelParam(a.data_ptr(), a.shape, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A], 1)
+    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A])
 
     if block_scale_type == "mixed":
-        b_desc_or_tensor = b
+        b_desc = TensorDescriptor(
+            b,
+            shape=[N, K // ELEM_PER_BYTE_B],
+            strides=[K // ELEM_PER_BYTE_B, 1],
+            block_shape=[BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
+        )
     else:
-        b_desc_or_tensor = TmaDescKernelParam(b.data_ptr(), b.shape, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B], 1)
+        b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B])
 
     epsilon = 1e-8
     a_scale = torch.rand((M // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
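Note on the mixed-precision branch above: the descriptor for the packed fp4 operand spells out its shape, strides, and block shape explicitly in units of the packed storage (K // ELEM_PER_BYTE_B along the inner dimension) rather than deriving them from the tensor via from_tensor, mirroring the device-side tl.make_tensor_descriptor call this hunk removes from the kernel.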
@@ -327,7 +299,7 @@ def unpack_scale(packed):
         "ELEM_PER_BYTE_B": ELEM_PER_BYTE_B,
         "VEC_SIZE": VEC_SIZE,
     }
-    return a_desc, a_scale, b_desc_or_tensor, b_scale, configs, reference
+    return a_desc, a_scale, b_desc, b_scale, configs, reference
 
 
 def validate_block_scaled(M, N, K, block_scale_type="nvfp4"):
@@ -340,9 +312,9 @@ def alloc_fn(size: int, align: int, _):
     # TMA load for mixed-precision fp4 is supported only by device TMA.
     triton.set_allocator(alloc_fn)
 
-    a_desc, a_scale, b_desc_or_tensor, b_scale, configs, reference = initialize_block_scaled(
-        M, N, K, block_scale_type, compute_reference=True)
-    output = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
+    a_desc, a_scale, b_desc, b_scale, configs, reference = initialize_block_scaled(M, N, K, block_scale_type,
+                                                                                   compute_reference=True)
+    output = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
     torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3)
     print(f"✅ (pass {block_scale_type})")
 
@@ -353,19 +325,13 @@ def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
     N = 8192
     print(f"Problem Shape = {M}x{N}x{K}")
 
-    def alloc_fn(size: int, align: int, _):
-        return torch.empty(size, dtype=torch.int8, device="cuda")
-
-    if block_scale_type == "mixed":
-        triton.set_allocator(alloc_fn)
-
-    a_desc, a_scale, b_desc_or_tensor, b_scale, configs, _ = initialize_block_scaled(
-        M, N, K, block_scale_type, compute_reference=False)
-    _ = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
+    a_desc, a_scale, b_desc, b_scale, configs, _ = initialize_block_scaled(M, N, K, block_scale_type,
+                                                                           compute_reference=False)
+    _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
 
     proton.activate(0)
     for _ in range(reps):
-        _ = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
+        _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
     proton.deactivate(0)
     print("Done benchmarking")