
Commit 48862cd

[AMD][Gluon] Support global/buffer load to shared (#7880)
This PR introduces the following new builtins in Gluon (see the usage sketch below):

- `global_load_to_shared`: similar to `ttgl.nvidia.ampere.async_copy.async_copy_global_to_shared`.
- `async_wait`: similar to `ttgl.nvidia.ampere.async_copy.wait_group`.
- `load_shared_relaxed`: loads from shared memory with a hint to the compiler not to insert a fence. It should be used paired with `async_wait`: this function annotates the issued local load op to prevent LLVM from emitting conservative wait counts before the local load, following the logic of [`annotateLocalLoadsSyncedViaAsyncWait`](https://github.com/triton-lang/triton/blob/4530b0b5d01fff3682aa8f4f312db1b7140906dc/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp#L52-L65).

Along the way, there are other small changes:

- Support broadcast `mask` and `other` in `buffer_load_to_shared`.
- Expose `other` in `create_async_copy_global_to_local` for `global_load_to_shared`.
- Change `buffer_load_to_shared` to CDNA4-only.
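Taken together, the three builtins form one async-copy pipeline: issue the copy, wait on it, then read shared memory without an extra fence. Below is a minimal sketch of the intended pairing, adapted from `test_amd_direct_load_to_shared` added in this PR; it assumes a CDNA4 device, and the import paths follow the test files in this diff.

```python
# Minimal sketch of the new Gluon builtins working together
# (adapted from the tests in this PR; assumes a CDNA4 device).
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy


@gluon.jit
def copy_kernel(a_ptr, b_ptr):
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
    shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])

    smem = ttgl.allocate_shared_memory(a_ptr.dtype.element_ty, [128, 16], shared)
    offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \
        ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :]

    # Issue an async copy from global memory directly into shared memory.
    cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
    # Block until all outstanding async copies have landed (count = 0).
    cdna4_async_copy.async_wait(0)
    # Read shared memory without an extra fence: the local load is annotated
    # (ttg.amdgpu.syncedViaAsyncWait) so the backend skips conservative wait
    # counts before it.
    a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
    ttgl.store(b_ptr + offsets, a)
```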
1 parent 37f2659 commit 48862cd


9 files changed (+439, -55 lines)


python/src/gluon_ir.cc

Lines changed: 3 additions & 3 deletions
```diff
@@ -403,11 +403,11 @@ void init_gluon_ir(py::module &&m) {
       })
       .def("create_async_copy_global_to_local",
            [](GluonOpBuilder &self, Value smem, Value pointer, Value mask,
-              tt::CacheModifier cacheModifier,
+              Value other, tt::CacheModifier cacheModifier,
               tt::EvictionPolicy evictionPolicy, bool isVolatile) {
              self.create<ttg::AsyncCopyGlobalToLocalOp>(
-                 pointer, smem, mask,
-                 /*other*/ Value{}, cacheModifier, evictionPolicy, isVolatile);
+                 pointer, smem, mask, other, cacheModifier, evictionPolicy,
+                 isVolatile);
           })
       .def("create_async_copy_mbarrier_arrive",
            [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) {
```
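With the extra `Value other` slot, the frontend can forward a fill value instead of the hard-coded `Value{}`. A hedged sketch of the Python-side call into the updated binding follows; the argument order mirrors the C++ lambda above, but the wrapper shape itself is illustrative, not this PR's actual frontend code.

```python
# Illustrative only: how a frontend lowering can call the updated binding.
# Argument order mirrors the C++ lambda above; the wrapper is not PR code.
def lower_global_load_to_shared(builder, smem, ptr, mask, other,
                                cache_modifier, eviction_policy, is_volatile):
    # `other` now sits between `mask` and the cache modifier; previously the
    # binding passed /*other*/ Value{} unconditionally.
    builder.create_async_copy_global_to_local(smem, ptr, mask, other,
                                              cache_modifier, eviction_policy,
                                              is_volatile)
```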

python/test/gluon/test_core.py

Lines changed: 38 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 import torch
 import pytest
+import re
 
 import triton
 import triton.language as tl
@@ -10,6 +11,7 @@
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
 from triton.experimental.gluon.language.nvidia.hopper import tma, fence_async_shared
 from triton.experimental.gluon.language.nvidia import hopper
+from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
 from triton.experimental.gluon.language.extra import libdevice
 
 
@@ -149,6 +151,42 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
 
 
+@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.parametrize("use_buffer_load", [True, False])
+def test_amd_direct_load_to_shared(use_buffer_load):
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(a_ptr.dtype.element_ty, [128, 16], shared)
+        offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \
+            ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :]
+        if use_buffer_load:
+            cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets)
+        else:
+            cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets)
+
+        cdna4_async_copy.async_wait(0)
+        a = cdna4_async_copy.load_shared_relaxed(smem, blocked)
+
+        ttgl.store(b_ptr + offsets, a)
+
+    torch.manual_seed(0)
+    a = torch.randn((128, 16), dtype=torch.float16, device='cuda')
+    b = torch.empty_like(a)
+    pgm = kernel[(1, )](a, b, use_buffer_load)
+
+    torch.testing.assert_close(a, b)
+    assert re.search(r'ttg\.local_load .* \{ttg\.amdgpu\.syncedViaAsyncWait = true\}', pgm.asm['ttgir'], re.MULTILINE)
+    if use_buffer_load:
+        assert re.search(r"buffer_load.*lds$", pgm.asm['amdgcn'], re.MULTILINE)
+    else:
+        assert re.search(r"global_load_lds", pgm.asm['amdgcn'], re.MULTILINE)
+    assert 'vmcnt(0)' in pgm.asm['amdgcn']
+
+
 @pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
 @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
 @pytest.mark.parametrize("num_warps", [4, 8])
```

python/test/gluon/test_frontend.py

Lines changed: 231 additions & 8 deletions
```diff
@@ -11,6 +11,7 @@
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton.experimental.gluon.language.amd import _layouts as amd_layouts
+from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
 from triton.experimental.gluon.language.extra import libdevice
 
 from triton._filecheck import filecheck_test, run_parser
@@ -1590,7 +1591,175 @@ def test_infer_layout_for_amd_mfma(target):
 """)
 
 
-@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_load_shared_relaxed(target):
+
+    @gluon.jit
+    def kernel():
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared)
+        cdna4_async_copy.load_shared_relaxed(smem, blocked)
+
+    mod = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked>
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_load_shared_relaxed_in_loop(target):
+
+    @gluon.jit
+    def kernel():
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared)
+        for i in range(10):
+            cdna4_async_copy.load_shared_relaxed(smem, blocked)
+
+    mod = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %1 = arith.bitcast %c0_i32 : i32 to i32
+    %2 = arith.bitcast %c10_i32 : i32 to i32
+    %3 = arith.bitcast %c1_i32 : i32 to i32
+    %4 = ub.poison : i32
+    scf.for %arg0 = %1 to %2 step %3 : i32 {
+      %5 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked>
+    }
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_global_load_to_shared(target):
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
+        offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \
+            ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :]
+
+        cdna4_async_copy.global_load_to_shared(smem, ptr + offsets)
+        cdna4_async_copy.async_wait(0)
+
+    ptr = MockTensor(ttgl.float16)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %c16_i32 = arith.constant 16 : i32
+    %c16_i32_0 = arith.constant 16 : i32
+    %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
+    %3 = arith.muli %2, %cst : tensor<128x1xi32, #blocked>
+    %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %6 = tt.broadcast %3 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %11 = ttg.async_copy_global_to_local %10, %0 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %12 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_global_load_to_shared_with_broadcast(target):
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
+        y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
+        x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
+        offsets = y_offset[:, None] * 16 + x_offset[None, :]
+
+        mask = (y_offset < 64)[:, None]
+        other = tl.cast(0.0, ptr.dtype.element_ty)
+
+        cdna4_async_copy.global_load_to_shared(smem, ptr + offsets, mask, other)
+        cdna4_async_copy.async_wait(0)
+
+    ptr = MockTensor(ttgl.float16)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %c16_i32 = arith.constant 16 : i32
+    %c16_i32_0 = arith.constant 16 : i32
+    %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
+    %4 = arith.muli %3, %cst : tensor<128x1xi32, #blocked>
+    %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
+    %c64_i32 = arith.constant 64 : i32
+    %cst_1 = arith.constant dense<64> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %9 = arith.cmpi slt, %1, %cst_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked>
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %11 = arith.truncf %cst_2 : f32 to f16
+    %12 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %13 = tt.addptr %12, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %14 = tt.broadcast %10 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
+    %15 = tt.splat %11 : f16 -> tensor<128x16xf16, #blocked>
+    %16 = ttg.async_copy_global_to_local %13, %0 mask %14 other %15 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %17 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_buffer_load_to_shared(target):
 
     @gluon.jit
@@ -1601,7 +1770,7 @@ def kernel(ptr):
         dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared)
         offsets = ttgl.arange(0, 256, layout=blocked)
 
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets)
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets)
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)
@@ -1621,7 +1790,61 @@ def kernel(ptr):
 """)
 
 
-@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_buffer_load_to_shared_with_broadcast(target):
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked1: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [4, 64], shared)
+
+        y_index = ttgl.arange(0, 4, layout=ttgl.SliceLayout(1, blocked1))
+        x_index = ttgl.arange(0, 64, layout=ttgl.SliceLayout(0, blocked1))
+        offsets = y_index[:, None] * 64 + x_index[None, :]
+
+        mask = (y_index < 2)[:, None]
+        other = 0.0
+
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other)
+
+    ptr = MockTensor(ttgl.float32)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x64xf32, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi32, #blocked>
+    %c64_i32 = arith.constant 64 : i32
+    %c64_i32_0 = arith.constant 64 : i32
+    %cst = arith.constant dense<64> : tensor<4x1xi32, #blocked>
+    %4 = arith.muli %3, %cst : tensor<4x1xi32, #blocked>
+    %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<4x1xi32, #blocked> -> tensor<4x64xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<4x64xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<4x64xi32, #blocked>
+    %c2_i32 = arith.constant 2 : i32
+    %cst_1 = arith.constant dense<2> : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %9 = arith.cmpi slt, %1, %cst_1 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<4xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi1, #blocked>
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %11 = tt.broadcast %10 : tensor<4x1xi1, #blocked> -> tensor<4x64xi1, #blocked>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x64xf32, #blocked>
+    %12 = amdgpu.buffer_load_to_local %arg0[%8] mask = %11 other = %cst_3 into %0 : <f32>[tensor<4x64xi32, #blocked>] tensor<4x64xf32, #blocked> -> <4x64xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_buffer_load_to_shared_mask_other(target):
 
     @gluon.jit
@@ -1634,7 +1857,7 @@ def kernel(ptr):
 
         mask = ttgl.full([256], 1, ttgl.int1, layout=blocked)
         other = ttgl.full([256], 0, ptr.dtype.element_ty, layout=blocked)
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, mask, other)
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other)
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)
@@ -1658,7 +1881,7 @@ def kernel(ptr):
 """)
 
 
-@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_buffer_load_to_shared_cache_mods(target):
 
     @gluon.jit
@@ -1669,9 +1892,9 @@ def kernel(ptr):
         dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared)
         offsets = ttgl.arange(0, 256, layout=blocked)
 
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca")
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg")
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv")
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)
```
