 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton.experimental.gluon.language.amd import _layouts as amd_layouts
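+# CDNA4-specific async copy ops; aliased so they do not shadow the Blackwell async_copy import above.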
+from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
 from triton.experimental.gluon.language.extra import libdevice
 
 from triton._filecheck import filecheck_test, run_parser
@@ -1590,7 +1591,175 @@ def test_infer_layout_for_amd_mfma(target):
15901591""" )
15911592
15921593
1593- @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA3 , HIP_TARGET_CDNA4 ])
1594+ @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA4 ])
+def test_amd_load_shared_relaxed(target):
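+    # load_shared_relaxed should lower to a ttg.local_load tagged with
+    # ttg.amdgpu.syncedViaAsyncWait = true, marking the read as synchronized
+    # by an async wait rather than a barrier.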
+
+    @gluon.jit
+    def kernel():
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared)
+        cdna4_async_copy.load_shared_relaxed(smem, blocked)
+
+    mod = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked>
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_load_shared_relaxed_in_loop(target):
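+    # Same tagged ttg.local_load as above, but it must be emitted inside the
+    # scf.for body rather than hoisted out.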
+
+    @gluon.jit
+    def kernel():
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared)
+        for i in range(10):
+            cdna4_async_copy.load_shared_relaxed(smem, blocked)
+
+    mod = run_parser(kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %1 = arith.bitcast %c0_i32 : i32 to i32
+    %2 = arith.bitcast %c10_i32 : i32 to i32
+    %3 = arith.bitcast %c1_i32 : i32 to i32
+    %4 = ub.poison : i32
+    scf.for %arg0 = %1 to %2 step %3 : i32 {
+      %5 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked>
+    }
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_global_load_to_shared(target):
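+    # global_load_to_shared plus async_wait should lower to
+    # ttg.async_copy_global_to_local followed by ttg.async_wait {num = 0}.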
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
+        offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \
+            ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :]
+
+        cdna4_async_copy.global_load_to_shared(smem, ptr + offsets)
+        cdna4_async_copy.async_wait(0)
+
+    ptr = MockTensor(ttgl.float16)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %c16_i32 = arith.constant 16 : i32
+    %c16_i32_0 = arith.constant 16 : i32
+    %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
+    %3 = arith.muli %2, %cst : tensor<128x1xi32, #blocked>
+    %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %6 = tt.broadcast %3 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %11 = ttg.async_copy_global_to_local %10, %0 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %12 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
+def test_amd_global_load_to_shared_with_broadcast(target):
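+    # The rank-reduced mask and scalar other must be broadcast/splat to the
+    # full 128x16 shape before being attached to the async copy.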
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
+        y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
+        x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
+        offsets = y_offset[:, None] * 16 + x_offset[None, :]
+
+        mask = (y_offset < 64)[:, None]
+        other = tl.cast(0.0, ptr.dtype.element_ty)
+
+        cdna4_async_copy.global_load_to_shared(smem, ptr + offsets, mask, other)
+        cdna4_async_copy.async_wait(0)
+
+    ptr = MockTensor(ttgl.float16)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %c16_i32 = arith.constant 16 : i32
+    %c16_i32_0 = arith.constant 16 : i32
+    %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
+    %4 = arith.muli %3, %cst : tensor<128x1xi32, #blocked>
+    %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
+    %c64_i32 = arith.constant 64 : i32
+    %cst_1 = arith.constant dense<64> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %9 = arith.cmpi slt, %1, %cst_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked>
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %11 = arith.truncf %cst_2 : f32 to f16
+    %12 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %13 = tt.addptr %12, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %14 = tt.broadcast %10 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
+    %15 = tt.splat %11 : f16 -> tensor<128x16xf16, #blocked>
+    %16 = ttg.async_copy_global_to_local %13, %0 mask %14 other %15 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %17 = ttg.async_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_buffer_load_to_shared(target):
 
     @gluon.jit
@@ -1601,7 +1770,7 @@ def kernel(ptr):
         dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared)
         offsets = ttgl.arange(0, 256, layout=blocked)
 
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets)
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets)
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)
@@ -1621,7 +1790,61 @@ def kernel(ptr):
16211790""" )
16221791
16231792
1624- @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA3 , HIP_TARGET_CDNA4 ])
1793+ @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA4 ])
+def test_buffer_load_to_shared_with_broadcast(target):
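+    # Same broadcast handling for buffer_load_to_shared: mask and other are
+    # expanded to 4x64 and passed to amdgpu.buffer_load_to_local.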
+
+    @gluon.jit
+    def kernel(ptr):
+        blocked1: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+        shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+        dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [4, 64], shared)
+
+        y_index = ttgl.arange(0, 4, layout=ttgl.SliceLayout(1, blocked1))
+        x_index = ttgl.arange(0, 64, layout=ttgl.SliceLayout(0, blocked1))
+        offsets = y_index[:, None] * 64 + x_index[None, :]
+
+        mask = (y_index < 2)[:, None]
+        other = 0.0
+
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other)
+
+    ptr = MockTensor(ttgl.float32)
+    mod = run_parser(kernel, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<4x64xf32, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi32, #blocked>
+    %c64_i32 = arith.constant 64 : i32
+    %c64_i32_0 = arith.constant 64 : i32
+    %cst = arith.constant dense<64> : tensor<4x1xi32, #blocked>
+    %4 = arith.muli %3, %cst : tensor<4x1xi32, #blocked>
+    %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<4x1xi32, #blocked> -> tensor<4x64xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<4x64xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<4x64xi32, #blocked>
+    %c2_i32 = arith.constant 2 : i32
+    %cst_1 = arith.constant dense<2> : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %9 = arith.cmpi slt, %1, %cst_1 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<4xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi1, #blocked>
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %11 = tt.broadcast %10 : tensor<4x1xi1, #blocked> -> tensor<4x64xi1, #blocked>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x64xf32, #blocked>
+    %12 = amdgpu.buffer_load_to_local %arg0[%8] mask = %11 other = %cst_3 into %0 : <f32>[tensor<4x64xi32, #blocked>] tensor<4x64xf32, #blocked> -> <4x64xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+""")
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4])
 def test_buffer_load_to_shared_mask_other(target):
 
     @gluon.jit
@@ -1634,7 +1857,7 @@ def kernel(ptr):
 
         mask = ttgl.full([256], 1, ttgl.int1, layout=blocked)
         other = ttgl.full([256], 0, ptr.dtype.element_ty, layout=blocked)
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, mask, other)
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other)
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)
@@ -1658,7 +1881,7 @@ def kernel(ptr):
16581881""" )
16591882
16601883
1661- @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA3 , HIP_TARGET_CDNA4 ])
1884+ @pytest .mark .parametrize ("target" , [HIP_TARGET_CDNA4 ])
16621885def test_buffer_load_to_shared_cache_mods (target ):
16631886
16641887 @gluon .jit
@@ -1669,9 +1892,9 @@ def kernel(ptr):
         dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared)
         offsets = ttgl.arange(0, 256, layout=blocked)
 
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca")
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg")
-        ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg")
+        cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv")
 
     ptr = MockTensor(ttgl.float32)
     mod = run_parser(kernel, *make_args(ptr), target=target)