
Commit 5eee385

Support for host tensor descriptors on devices that don't support TMA descriptors (#6811)
This implements host-side tensor descriptors by updating the driver to recognise the decomposition of a tensor descriptor into a base pointer, shape, and strides. If a CUDA device would prefer to keep the tensor descriptor intact, it may add tensor-descriptor metadata to the kernel; this is currently the default for CUDA devices that support TMA descriptors.
1 parent 607c50c commit 5eee385
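
To make the new behaviour concrete, here is a minimal usage sketch adapted from the tests added in this commit; the kernel name, shapes, and the device choice are illustrative and not part of the diff:

# Illustrative sketch only; adapted from the tests below, not part of the commit.
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def copy_block_kernel(out_ptr, desc, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # The kernel loads a block through the descriptor; the backend decides
    # whether this lowers to a TMA descriptor or to the new pointer-plus-
    # shape/strides fallback.
    block = desc.load([0, 0])
    idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
    tl.store(out_ptr + idx, block)


M_BLOCK, N_BLOCK = 16, 32
inp = torch.randn((2 * M_BLOCK, 2 * N_BLOCK), dtype=torch.float16, device="cuda")
out = inp.new_empty((M_BLOCK, N_BLOCK))

# The descriptor is built on the host from the base tensor, its shape and
# strides, and the block shape the kernel will load.
desc = TensorDescriptor(inp, shape=inp.shape, strides=inp.stride(), block_shape=[M_BLOCK, N_BLOCK])
copy_block_kernel[(1, )](out, desc, M_BLOCK, N_BLOCK)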

4 files changed: +219 -131 lines


python/test/unit/cuda/test_tensor_descriptor.py (1 addition & 102 deletions)

@@ -1,111 +1,10 @@
-import pytest
 import torch

 import triton
-import triton.language as tl
-from triton._internal_testing import is_interpreter, numpy_random, to_triton, requires_tma, unwrap_tensor, tma_dtypes
+from triton._internal_testing import requires_tma
 from triton.tools.tensor_descriptor import TensorDescriptor


-@requires_tma
-@pytest.mark.interpreter()
-@pytest.mark.parametrize("dtype_str", tma_dtypes)
-@pytest.mark.parametrize("num_ctas", [1, 2])
-@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128), (512, 32), (1, 1024)])
-def test_host_tensor_descriptor_load(dtype_str, num_ctas, M_BLOCK, N_BLOCK):
-
-    @triton.jit(debug=True)
-    def kernel(out_ptr, desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
-        assert desc.shape[0] == M
-        assert desc.shape[1] == N
-        assert desc.strides[0] == N
-        assert desc.strides[1] == 1
-        assert desc.block_shape == [M_BLOCK, N_BLOCK]
-        block = desc.load([M_BLOCK, 2 * N_BLOCK])
-        idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
-        tl.store(out_ptr + idx, block)
-
-    M, N = M_BLOCK * 3, N_BLOCK * 4
-    inp = to_triton(numpy_random((M, N), dtype_str), device="cuda", dst_type=dtype_str)
-    out = inp.new_empty((M_BLOCK, N_BLOCK))
-
-    inp_desc = TensorDescriptor(inp, shape=inp.shape, strides=inp.stride(), block_shape=[M_BLOCK, N_BLOCK])
-    kernel[(1, )](out, inp_desc, M, N, M_BLOCK, N_BLOCK, num_ctas=num_ctas)
-
-    expect = unwrap_tensor(inp)[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK]
-    torch.testing.assert_close(expect, unwrap_tensor(out))
-
-
-@triton.jit
-def matmul_kernel_host_tensor_descriptor(a_desc, b_desc, c_desc):
-    K = a_desc.shape[1]
-    BLOCK_M: tl.constexpr = a_desc.block_shape[0]
-    BLOCK_K: tl.constexpr = a_desc.block_shape[1]
-    BLOCK_N: tl.constexpr = b_desc.block_shape[1]
-
-    pid_m = tl.program_id(axis=0)
-    pid_n = tl.program_id(axis=1)
-    offs_am = pid_m * BLOCK_M
-    offs_bn = pid_n * BLOCK_N
-    offs_k = 0
-
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_K)):
-        a = a_desc.load([offs_am, offs_k])
-        b = b_desc.load([offs_k, offs_bn])
-        accumulator = tl.dot(a, b, acc=accumulator)
-        offs_k += BLOCK_K
-    accumulator = accumulator.to(a_desc.dtype)
-    c_desc.store([offs_am, offs_bn], accumulator)
-
-
-@requires_tma
-@pytest.mark.interpreter()
-@pytest.mark.parametrize("num_ctas", [1, 2])
-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, num_stages", [
-    (128, 128, 16, 1),
-    (512, 64, 32, 2),
-    (64, 512, 32, 2),
-    (128, 128, 16, 4),
-    (64, 128, 32, 4),
-    (32, 32, 32, 4),
-    (256, 128, 32, 4),
-])
-def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, BLOCK_K):
-    device = "cuda"
-    if is_interpreter():
-        M, N, K = BLOCK_M, BLOCK_N, BLOCK_K
-    else:
-        M, N, K = 1024, 512, 256
-    torch.manual_seed(42)
-    A = torch.randn((M, K), dtype=torch.float16, device=device)
-    B = torch.randn((K, N), dtype=torch.float16, device=device)
-    C = torch.empty((M, N), dtype=torch.float16, device=device)
-    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1)
-
-    A_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K])
-    B_desc = TensorDescriptor(B, B.shape, B.stride(), [BLOCK_K, BLOCK_N])
-    C_desc = TensorDescriptor(C, C.shape, C.stride(), [BLOCK_M, BLOCK_N])
-
-    kernel = matmul_kernel_host_tensor_descriptor[grid](
-        A_desc,
-        B_desc,
-        C_desc,  #
-        num_warps=8,
-        num_stages=num_stages,
-        num_ctas=num_ctas,
-    )
-    ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
-    torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
-    if is_interpreter():
-        return
-
-    if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and torch.cuda.get_device_capability()[0] == 9:
-        # TODO: The use of stmatrix for Blackwell is currently not supported.
-        # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
-        assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]
-
-
 @requires_tma
 def test_specialization_after_host_tensordesc():

python/test/unit/language/test_tensor_descriptor.py (102 additions & 2 deletions)

@@ -1494,8 +1494,6 @@ def test_tensor_descriptor_reduce(kind, descriptor, dtype_str, num_ctas, M_BLOCK
     if not is_native:
         if num_ctas != 1:
             pytest.skip("Multi-CTA not supported")
-        if descriptor == "host":
-            pytest.skip("NYI: Host side tensor descriptor fallback")
     if is_hip_cdna3() and (kind, dtype_str, M_BLOCK, N_BLOCK) in REDUCE_SKIP_HIP_CDNA3:
         pytest.skip("Broken on rocm")

@@ -1573,3 +1571,105 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     expect = REDUCE_OP[kind](inp, out)
     kernel[(grid_m, grid_n)](out_desc, out, inp, M, N, M_BLOCK, N_BLOCK, kind, num_ctas=num_ctas)
     torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)
+
+
+@pytest.mark.interpreter()
+@pytest.mark.parametrize("dtype_str", tma_dtypes)
+@pytest.mark.parametrize("num_ctas", [1, 2])
+@pytest.mark.parametrize("M_BLOCK,N_BLOCK", [(2, 16), (8, 16), (8, 32), (8, 128)])
+def test_host_tensor_descriptor_load(dtype_str, num_ctas, M_BLOCK, N_BLOCK, device):
+    if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
+        pytest.skip("CTAs is unsupported for these cards")
+
+    @triton.jit(debug=True)
+    def kernel(out_ptr, desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+        assert desc.shape[0] == M
+        assert desc.shape[1] == N
+        assert desc.strides[0] == N
+        assert desc.strides[1] == 1
+        assert desc.block_shape == [M_BLOCK, N_BLOCK]
+        block = desc.load([M_BLOCK, 2 * N_BLOCK])
+        idx = tl.arange(0, M_BLOCK)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :]
+        tl.store(out_ptr + idx, block)
+
+    M, N = M_BLOCK * 3, N_BLOCK * 4
+    inp = to_triton(numpy_random((M, N), dtype_str), device=device, dst_type=dtype_str)
+    out = inp.new_empty((M_BLOCK, N_BLOCK))
+
+    inp_desc = TensorDescriptor(inp, shape=inp.shape, strides=inp.stride(), block_shape=[M_BLOCK, N_BLOCK])
+    kernel[(1, )](out, inp_desc, M, N, M_BLOCK, N_BLOCK, num_ctas=num_ctas)
+
+    expect = unwrap_tensor(inp)[1 * M_BLOCK:2 * M_BLOCK, 2 * N_BLOCK:3 * N_BLOCK]
+    torch.testing.assert_close(expect, unwrap_tensor(out))
+
+
+@triton.jit
+def matmul_kernel_host_tensor_descriptor(a_desc, b_desc, c_desc):
+    K = a_desc.shape[1]
+    BLOCK_M: tl.constexpr = a_desc.block_shape[0]
+    BLOCK_K: tl.constexpr = a_desc.block_shape[1]
+    BLOCK_N: tl.constexpr = b_desc.block_shape[1]
+
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    offs_am = pid_m * BLOCK_M
+    offs_bn = pid_n * BLOCK_N
+    offs_k = 0
+
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        a = a_desc.load([offs_am, offs_k])
+        b = b_desc.load([offs_k, offs_bn])
+        accumulator = tl.dot(a, b, acc=accumulator)
+        offs_k += BLOCK_K
+    accumulator = accumulator.to(a_desc.dtype)
+    c_desc.store([offs_am, offs_bn], accumulator)
+
+
+@pytest.mark.interpreter()
+@pytest.mark.parametrize("num_ctas", [1, 2])
+@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, num_stages", [
+    (128, 128, 16, 1),
+    (256, 64, 32, 2),
+    (64, 512, 32, 2),
+    (128, 128, 16, 4),
+    (64, 128, 32, 4),
+    (32, 32, 32, 4),
+    (256, 128, 32, 4),
+])
+def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, BLOCK_K, device):
+    if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
+        pytest.skip("CTAs is unsupported for these cards")
+
+    if is_hip() and (BLOCK_M, BLOCK_N, BLOCK_K, num_stages) == (256, 128, 32, 4):
+        pytest.skip("Insufficient shared memory on HIP devices")
+
+    if is_interpreter():
+        M, N, K = BLOCK_M, BLOCK_N, BLOCK_K
+    else:
+        M, N, K = 1024, 512, 256
+    torch.manual_seed(42)
+    A = torch.randn((M, K), dtype=torch.float16, device=device)
+    B = torch.randn((K, N), dtype=torch.float16, device=device)
+    C = torch.empty((M, N), dtype=torch.float16, device=device)
+    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1)
+
+    A_desc = TensorDescriptor(A, A.shape, A.stride(), [BLOCK_M, BLOCK_K])
+    B_desc = TensorDescriptor(B, B.shape, B.stride(), [BLOCK_K, BLOCK_N])
+    C_desc = TensorDescriptor(C, C.shape, C.stride(), [BLOCK_M, BLOCK_N])
+
+    kernel = matmul_kernel_host_tensor_descriptor[grid](
+        A_desc,
+        B_desc,
+        C_desc,  #
+        num_warps=8,
+        num_stages=num_stages,
+        num_ctas=num_ctas,
+    )
+    ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
+    torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
+
+    if BLOCK_M >= 64 * num_ctas and BLOCK_N >= 64 and is_cuda() and torch.cuda.get_device_capability()[0] == 9:
+        # TODO: The use of stmatrix for Blackwell is currently not supported.
+        # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
+        assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm["ptx"]

third_party/amd/backend/driver.py (62 additions & 2 deletions)

@@ -4,12 +4,14 @@
 import subprocess
 import sysconfig
 import tempfile
+import re
 from pathlib import Path
 from triton.runtime.build import _build
 from triton import knobs
 from triton.runtime.cache import get_cache_manager
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver, platform_key
+from triton.tools.tensor_descriptor import TensorDescriptor

 dirname = os.path.dirname(os.path.realpath(__file__))
 include_dir = [os.path.join(dirname, "include")]
@@ -193,8 +195,37 @@ def ty_to_cpp(ty):
     }[ty]


+_BASE_ARGS_FORMAT = "piiiKKOOOO"
+
+
 def make_launcher(constants, signature, warp_size):

+    def _expand_signature(signature):
+        output = []
+        # Expand tensor descriptor arguments into base pointer, shape, and
+        # strides
+        for sig in signature:
+            if isinstance(sig, str) and sig.startswith("tensordesc"):
+                ndim = sig.count(",") + 1
+                dtype = re.match("tensordesc<([^[>]*)", sig).group(1)
+
+                output.append("*" + dtype)
+                for _ in range(2 * ndim):
+                    output.append("i64")
+                # Currently the host side tensor descriptors get passed in as a
+                # tensor desc, shape, and strides. We have no way to use these
+                # shape and strides when processing tensor descriptors which is
+                # why we provide our own decomposition above. Sadly this means
+                # we have to pass the shape and strides twice.
+                for _ in range(ndim):
+                    output.append("i32")
+                for _ in range(ndim):
+                    output.append("i64")
+            else:
+                output.append(sig)
+
+        return output
+
     def _serialize_signature(sig):
         if isinstance(sig, tuple):
             return ','.join(map(_serialize_signature, sig))
@@ -232,8 +263,10 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty_to_cpp(ty)]

+    signature = {idx: s for idx, s in enumerate(_expand_signature(signature.values()))}
+
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "piiiKKOOOO" + args_format
+    format = _BASE_ARGS_FORMAT + args_format
     signature = ','.join(map(_serialize_signature, signature.values()))
     signature = list(filter(bool, signature.split(',')))
     signature = {i: s for i, s in enumerate(signature)}
@@ -494,6 +527,31 @@ def format_of(ty):
     return src


+def wrap_handle_tensor_descriptor(launcher):
+    """
+    Replace all tensor descriptors with the base ptr, shape, and strides
+    """
+
+    def inner(*args):
+        meta_args = args[:len(_BASE_ARGS_FORMAT)]
+        raw_kernel_args = args[len(_BASE_ARGS_FORMAT):]
+        final_args = []
+        for arg in raw_kernel_args:
+            if isinstance(arg, TensorDescriptor):
+                # Currently the host side tensor descriptors get decomposed in
+                # the frontend to tensor desc, shape, and strides. We have no
+                # way to use these shape and strides when processing tensor
+                # descriptors which is why we provide our own decomposition
+                # above. Sadly this means we have to pass the shape and strides
+                # twice.
+                final_args.extend([arg.base, *arg.shape, *arg.strides, *arg.shape, *arg.strides])
+            else:
+                final_args.append(arg)
+        return launcher(*meta_args, *final_args)
+
+    return inner
+
+
 class HIPLauncher(object):

     def __init__(self, src, metadata):
@@ -503,7 +561,9 @@ def __init__(self, src, metadata):
         signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, metadata.warp_size)
         mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
+        has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
+
+        self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch
         self.launch_cooperative_grid = metadata.launch_cooperative_grid

     def __call__(self, *args):