
Commit 4e7dc91

[AMD][GLUON] Expose WMMA for RDNA3 and RDNA4 (triton-lang#8111)
This PR exposes the WMMA instruction for RDNA3 and RDNA4. Similar to `amd.cdna4.mfma`, both are wrappers around `tl.dot` with additional checks on the input layouts. It also adds a WMMA kernel test, which has been run on a Radeon RX 9070 XT (RDNA4) and a Radeon RX 7900 XTX (RDNA3).
1 parent 8a5d1ee commit 4e7dc91
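
For orientation, a minimal sketch of the new API, condensed from the frontend test added below (RDNA4 variant shown; on RDNA3 the call is `ttgl.amd.rdna3.wmma` with `version=1` and operand kWidth 16). The import paths follow the convention used in the Gluon tests.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def kernel():
    # The accumulator carries the WMMA layout; each operand carries a
    # DotOperandLayout whose parent is that WMMA layout (kWidth 8 on RDNA4).
    wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[4, 1])
    a = ttgl.full([64, 64], 1.0, ttgl.float16, layout=ttgl.DotOperandLayout(0, wmma_layout, 8))
    b = ttgl.full([64, 64], 2.0, ttgl.float16, layout=ttgl.DotOperandLayout(1, wmma_layout, 8))
    acc = ttgl.full([64, 64], 0.0, ttgl.float32, layout=wmma_layout)
    acc = ttgl.amd.rdna4.wmma(a, b, acc)  # thin wrapper around tl.dot with layout checks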

File tree: 5 files changed, +210 -1 lines changed


python/test/gluon/test_core.py

Lines changed: 63 additions & 0 deletions
@@ -8,6 +8,8 @@
 from triton._internal_testing import (
     is_ampere_or_newer,
     is_blackwell,
+    is_hip_gfx11,
+    is_hip_gfx12,
     is_hip_cdna3,
     is_hip_cdna4,
     is_hopper_or_newer,
@@ -202,6 +204,67 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
     assert 'vmcnt(0)' in pgm.asm['amdgcn']


+@pytest.mark.skipif(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
+@pytest.mark.parametrize("M, N, K", [(64, 64, 64)])
+@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
+def test_amd_wmma(M, N, K, in_dtype):
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr,  #
+               stride_am, stride_ak,  #
+               stride_bk, stride_bn,  #
+               stride_cm, stride_cn,  #
+               BLOCK_SIZE_M: ttgl.constexpr,  #
+               BLOCK_SIZE_N: ttgl.constexpr,  #
+               BLOCK_SIZE_K: ttgl.constexpr,  #
+               BLOCKED_LAYOUT: ttgl.constexpr,  #
+               WMMA_LAYOUT: ttgl.constexpr,  #
+               K_WIDTH: ttgl.constexpr):
+        offs_am = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))
+        offs_bn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))
+
+        offs_ak = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))
+        offs_bk = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))
+
+        offs_a = offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak
+        offs_b = offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
+        a = ttgl.load(a_ptr + offs_a)
+        b = ttgl.load(b_ptr + offs_b)
+
+        a = ttgl.convert_layout(a, layout=ttgl.DotOperandLayout(0, WMMA_LAYOUT, K_WIDTH))
+        b = ttgl.convert_layout(b, layout=ttgl.DotOperandLayout(1, WMMA_LAYOUT, K_WIDTH))
+
+        acc = ttgl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], ttgl.float32, WMMA_LAYOUT)
+        if WMMA_LAYOUT.version == 1:
+            c = ttgl.amd.rdna3.wmma(a, b, acc)
+        else:
+            ttgl.static_assert(WMMA_LAYOUT.version == 2, "WMMA_LAYOUT.version must be 1 or 2")
+            c = ttgl.amd.rdna4.wmma(a, b, acc)
+        c = c.to(a_ptr.dtype.element_ty)
+
+        offs_cm = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, WMMA_LAYOUT))
+        offs_cn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, WMMA_LAYOUT))
+        offs_c = offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        ttgl.store(c_ptr + offs_c, c)
+
+    elem_type = torch.float16 if in_dtype == 'float16' else torch.bfloat16
+    a = torch.randn((M, K), device='cuda', dtype=elem_type)
+    b = torch.randn((K, N), device='cuda', dtype=elem_type)
+    c = torch.empty((M, N), device=a.device, dtype=elem_type)
+
+    blocked = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
+    wmma_version = 1 if is_hip_gfx11() else 2
+    k_width = 16 if is_hip_gfx11() else 8
+    wmma = ttgl.amd.AMDWMMALayout(wmma_version, True, [2, 2])
+    kernel[1, 1](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=M,
+                 BLOCK_SIZE_N=N, BLOCK_SIZE_K=K, BLOCKED_LAYOUT=blocked, WMMA_LAYOUT=wmma, K_WIDTH=k_width, num_warps=4)
+
+    ref = torch.matmul(a, b)
+    triton_output = c
+    torch.testing.assert_close(ref, triton_output)
+
+
 @pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
 @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
 @pytest.mark.parametrize("num_warps", [4, 8])
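
If it helps when reproducing this on RDNA hardware, the new test can be selected on its own; a minimal sketch using pytest's Python entry point, with the file path taken from the header above:

# Run only the new WMMA test from this file; it requires an RDNA3/RDNA4 GPU,
# otherwise the skipif marker above skips it.
import pytest

pytest.main(["python/test/gluon/test_core.py", "-k", "test_amd_wmma", "-v"])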

python/test/gluon/test_frontend.py

Lines changed: 73 additions & 0 deletions
@@ -26,6 +26,7 @@
 BLACKWELL_TARGET = GPUTarget("cuda", 100, 32)
 HOPPER_TARGET = GPUTarget("cuda", 90, 32)
 AMPERE_TARGET = GPUTarget("cuda", 80, 32)
+HIP_TARGET_RDNA3 = GPUTarget("hip", "gfx1100", 32)
 HIP_TARGET_RDNA4 = GPUTarget("hip", "gfx1200", 32)
 HIP_TARGET_CDNA3 = GPUTarget("hip", "gfx942", 64)
 HIP_TARGET_CDNA4 = GPUTarget("hip", "gfx950", 64)
@@ -2129,6 +2130,78 @@ def test_buffer_load_store_with_broadcast(target):
     """)


+@pytest.mark.parametrize("target", [HIP_TARGET_RDNA3])
+def test_amd_rdna3_wmma(target):
+
+    @gluon.jit
+    def kernel():
+        wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[4, 1])
+
+        a = ttgl.full([64, 64], 1.0, ttgl.float16, layout=ttgl.DotOperandLayout(0, wmma_layout, 16))
+        b = ttgl.full([64, 64], 2.0, ttgl.float16, layout=ttgl.DotOperandLayout(1, wmma_layout, 16))
+
+        acc = ttgl.full([64, 64], 0.0, ttgl.float32, layout=wmma_layout)
+        acc = ttgl.amd.rdna3.wmma(a, b, acc)
+
+        ttgl.static_assert(isinstance(acc, ttgl.tensor))
+        ttgl.static_assert(acc.type.layout == wmma_layout)
+
+    module = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_wmma<{version = 1, isTranspose = true, warpsPerCTA = [4, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f16
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
+    %cst_1 = arith.constant 2.000000e+00 : f16
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    %cst_5 = arith.constant 0.000000e+00 : f32
+    %0 = tt.dot %cst_0, %cst_2, %cst_4 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_RDNA4])
+def test_amd_rdna4_wmma(target):
+
+    @gluon.jit
+    def kernel():
+        wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[4, 1])
+
+        a = ttgl.full([64, 64], 1.0, ttgl.float16, layout=ttgl.DotOperandLayout(0, wmma_layout, 8))
+        b = ttgl.full([64, 64], 2.0, ttgl.float16, layout=ttgl.DotOperandLayout(1, wmma_layout, 8))
+
+        acc = ttgl.full([64, 64], 0.0, ttgl.float32, layout=wmma_layout)
+        acc = ttgl.amd.rdna4.wmma(a, b, acc)
+
+        ttgl.static_assert(isinstance(acc, ttgl.tensor))
+        ttgl.static_assert(acc.type.layout == wmma_layout)
+
+    module = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_wmma<{version = 2, isTranspose = true, warpsPerCTA = [4, 1]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f16
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_1 = arith.constant 2.000000e+00 : f16
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    %cst_5 = arith.constant 0.000000e+00 : f32
+    %0 = tt.dot %cst_0, %cst_2, %cst_4 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+}
+""")
+
+
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
 def test_amd_mfma(target):

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from ._layouts import AMDMFMALayout, AMDWMMALayout
 from . import cdna3, cdna4
+from . import rdna3, rdna4

-__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4"]
+__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4", "rdna3", "rdna4"]
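
With these exports in place, the new wrappers sit next to the existing CDNA helpers in the Gluon language namespace. A quick sanity check (import path assumed to follow the convention used in the Gluon tests):

from triton.experimental.gluon import language as ttgl

print(ttgl.amd.rdna3.wmma)   # version-1 WMMA wrapper (RDNA3)
print(ttgl.amd.rdna4.wmma)   # version-2 WMMA wrapper (RDNA4)
print(ttgl.amd.cdna4.mfma)   # pre-existing MFMA wrapper mentioned in the commit message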
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+from triton import knobs
+from triton.experimental.gluon.language import _core as ttgl
+from triton.experimental.gluon.language._semantic import _check
+
+from .._layouts import AMDWMMALayout
+from ..._layouts import DotOperandLayout
+from ..._core import builtin
+
+__all__ = ["wmma"]
+
+
+@builtin
+def wmma(a, b, acc, _semantic=None):
+    """
+    Computes matrix-multiplication of a * b + acc using AMD WMMA instruction.
+
+    Args:
+        a (tensor): The operand a to be multiplied.
+        b (tensor): The operand b to be multiplied.
+        acc (tensor): The accumulator tensor.
+    """
+    _check(acc is not None, lambda: "acc is required")
+    layout = acc.type.layout
+    _check(
+        isinstance(layout, AMDWMMALayout) and layout.version == 1,
+        lambda: "Expected layout to be an instance of AMDWMMALayout with version 1")
+    _check(
+        isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent == layout,
+        lambda: "Expected a's layout to be a DotOperandLayout with parent matching AMDWMMALayout")
+    _check(
+        isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout,
+        lambda: "Expected b's layout to be a DotOperandLayout with parent matching AMDWMMALayout")
+
+    handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
+                           out_dtype=acc.dtype).handle
+    return ttgl.tensor(handle, acc.type)
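
For reference, a small sketch of a layout pairing that satisfies the checks above for the version-1 (RDNA3) path, matching the kWidth used by the frontend test; the values are illustrative, not the only valid configuration:

from triton.experimental.gluon import language as ttgl

# The accumulator carries the WMMA layout itself; each operand carries a
# DotOperandLayout whose parent is that same WMMA layout.
wmma_v1 = ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[4, 1])
a_layout = ttgl.DotOperandLayout(0, wmma_v1, 16)   # operand A, kWidth 16 on RDNA3
b_layout = ttgl.DotOperandLayout(1, wmma_v1, 16)   # operand B
acc_layout = wmma_v1                               # accumulator

The rdna4 module below is identical except that it checks for `layout.version == 2` (and the tests use kWidth 8 there).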
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+from triton import knobs
+from triton.experimental.gluon.language import _core as ttgl
+from triton.experimental.gluon.language._semantic import _check
+
+from .._layouts import AMDWMMALayout
+from ..._layouts import DotOperandLayout
+from ..._core import builtin
+
+__all__ = ["wmma"]
+
+
+@builtin
+def wmma(a, b, acc, _semantic=None):
+    """
+    Computes matrix-multiplication of a * b + acc using AMD WMMA instruction.
+
+    Args:
+        a (tensor): The operand a to be multiplied.
+        b (tensor): The operand b to be multiplied.
+        acc (tensor): The accumulator tensor.
+    """
+    _check(acc is not None, lambda: "acc is required")
+    layout = acc.type.layout
+    _check(
+        isinstance(layout, AMDWMMALayout) and layout.version == 2,
+        lambda: "Expected layout to be an instance of AMDWMMALayout with version 2")
+    _check(
+        isinstance(a.type.layout, DotOperandLayout) and a.type.layout.parent == layout,
+        lambda: "Expected a's layout to be a DotOperandLayout with parent matching AMDWMMALayout")
+    _check(
+        isinstance(b.type.layout, DotOperandLayout) and b.type.layout.parent == layout,
+        lambda: "Expected b's layout to be a DotOperandLayout with parent matching AMDWMMALayout")
+
+    handle = _semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None,
+                           out_dtype=acc.dtype).handle
+    return ttgl.tensor(handle, acc.type)
