 import triton
 from triton.backends.compiler import GPUTarget
 
+from triton.experimental import gluon
+from triton.experimental.gluon import language as ttgl
+from triton.experimental.gluon.language.nvidia.blackwell import (
+    TensorMemoryLayout,
+    allocate_tensor_memory,
+    get_tmem_32x32b_reg_layout,
+    mbarrier,
+    tcgen05_mma,
+    tcgen05_commit,
+)
 
-def test_tmem_copy_2d():
-    if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] == 10:
-        pytest.skip("Test requires Blackwell target.")
 
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_tmem_copy_2d():
     device = "cuda"
 
     smem_h = 256
@@ -89,3 +98,174 @@ def test_tmem_copy_2d():
         for i in range(4):
             # Copied values are duplicated across warps
             assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_tmem_subslice_block_m_64():
+
+    @gluon.jit
+    def kernel(s_ptr, out_ptr):
+        BLOCK_M: ttgl.constexpr = 64
+        N: ttgl.constexpr = 128
+        BLOCK_N: ttgl.constexpr = 64
+
+        tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=True)
+        s_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=tmem_layout)
+        o_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=tmem_layout)
+
+        layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4)
+
+        offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :]
+        offsets = ttgl.convert_layout(offsets, layout)
+        s = ttgl.load(s_ptr + offsets)
+
+        s_tmem.store(s)
+        o_tmem.store(s)
+
+        p_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=False)
+        p_tmem = s_tmem.slice(0, N // 2)._reinterpret(ttgl.float16, [BLOCK_M, N], p_tmem_layout)
+        p_tmem.store(ttgl.full((BLOCK_M, N), 0.0, dtype=ttgl.float16, layout=layout))
+
+        d1_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, 1), unpacked=True)
+        d1_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, 1, (BLOCK_M, 1), num_warps=4)
+
+        m_tmem = s_tmem.slice(N // 4, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        m_tmem.store(ttgl.full((BLOCK_M, 1), 2.0, dtype=ttgl.float32, layout=d1_layout))
+        l_tmem = s_tmem.slice(N // 4 + 1, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        l_tmem.store(ttgl.full((BLOCK_M, 1), 3.0, dtype=ttgl.float32, layout=d1_layout))
+        a_tmem = s_tmem.slice(N // 4 + 2, 1)._reinterpret(ttgl.float32, [BLOCK_M, 1], d1_tmem_layout)
+        a_tmem.store(ttgl.full((BLOCK_M, 1), 4.0, dtype=ttgl.float32, layout=d1_layout))
+
+        s = s_tmem.load(layout)
+
+        ttgl.store(out_ptr + offsets, s)
+
+    torch.manual_seed(0)
+    s = torch.randn((64, 128), dtype=torch.float32, device="cuda")
+
+    out_tri = torch.empty_like(s)
+    compiled = kernel[(1, )](s, out_tri)
+
+    ttgir = compiled.asm["ttgir"]
+    # Check that we have two 64x128xf32 allocations.
+    assert ttgir.count("ttng.tmem_alloc") == 2
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf32") == 2
+
+    # Check that we allocated only 128 columns of TMEM.
+    llir = compiled.asm["llir"]
+    assert llir.count("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$1], 128")
+
+    # Given TMEM[0:32] is the slice of TMEM for warpgroup 0, the expected layout
+    # of S is
+    #
+    # TMEM[0:16]  = S[0:16, 0:64]
+    # TMEM[16:32] = S[0:16, 64:128]
+    #
+    # When slicing S to obtain P, we expect it to overlap with the left half,
+    # i.e. S[0:16, 0:32] and S[0:16, 64:96].
+    out_ref = s
+    out_ref[:, 0:32] = 0.0
+    out_ref[:, 64:96] = 0.0
+
+    # Given S = [s0, s1, s2, s3], they are arranged like
+    #
+    # TMEM[0:16]  = [s0, s1]
+    # TMEM[16:32] = [s2, s3]
+    #
+    # Thus slicing S at N // 4 will obtain an offset to the beginning of s1.
+    out_ref[:, 32] = 2.0
+    out_ref[:, 33] = 3.0
+    out_ref[:, 34] = 4.0
+
+    torch.testing.assert_close(out_ref, out_tri, atol=0, rtol=0)
+
+
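To make the folding described in the comments of test_tmem_subslice_block_m_64 concrete: with BLOCK_M=64, the left half of the tile's columns sits in rows [0, 16) and the right half in rows [16, 32) of the 32-row slice, sharing the same fp32 TMEM columns. The sketch below models only that index arithmetic in plain PyTorch; the fold_block_m_64 helper and the exact row assignment are illustrative assumptions, not part of the Gluon API.

    import torch

    # Illustrative model (an assumption for clarity, not real TMEM addressing):
    # fold a 16-row slab of S so its left column half lands in rows [0, 16) and
    # its right column half in rows [16, 32), sharing 64 fp32 columns.
    def fold_block_m_64(slab):  # slab: (16, 128) fp32
        return torch.cat([slab[:, 0:64], slab[:, 64:128]], dim=0)  # (32, 64)

    s = torch.randn(64, 128)
    tmem_rows = fold_block_m_64(s[0:16])

    # Column 32 of the folded view holds S[0:16, 32] in its upper rows and
    # S[0:16, 96] in its lower rows, which is why slicing at N // 4 == 32 lands
    # at the start of s1, and why zeroing the first 32 folded columns clears
    # S[:, 0:32] and S[:, 64:96].
    assert torch.equal(tmem_rows[0:16, 32], s[0:16, 32])
    assert torch.equal(tmem_rows[16:32, 32], s[0:16, 96])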
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+def test_block_m_64_mma():
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
+        BLOCK_M: ttgl.constexpr = 64
+        N: ttgl.constexpr = 128
+        BLOCK_N: ttgl.constexpr = 64
+
+        a_offsets = ttgl.arange(0, BLOCK_M)[:, None] * N + ttgl.arange(0, N)[None, :]
+        b_offsets = ttgl.arange(0, N)[:, None] * N + ttgl.arange(0, N)[None, :]
+
+        a_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(BLOCK_M, BLOCK_N, (BLOCK_M, N), num_warps=4)
+        b_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+        a_offsets = ttgl.convert_layout(a_offsets, a_layout)
+        b_offsets = ttgl.convert_layout(b_offsets, b_layout)
+
+        a = ttgl.load(a_ptr + a_offsets)
+        b = ttgl.load(b_ptr + b_offsets)
+        c = ttgl.load(c_ptr + a_offsets)
+
+        a_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=False)
+        acc_tmem_layout: ttgl.constexpr = TensorMemoryLayout((BLOCK_M, BLOCK_N), unpacked=True)
+        al_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout)
+        ar_tmem = allocate_tensor_memory(ttgl.float16, (BLOCK_M, N), layout=a_tmem_layout)
+        acc_tmem = allocate_tensor_memory(ttgl.float32, (BLOCK_M, N), layout=acc_tmem_layout)
+
+        a0, a1 = a.reshape((BLOCK_M, 2, N // 2)).permute(0, 2, 1).split()
+
+        al = ttgl.join(a0, a1).permute(0, 2, 1).reshape((BLOCK_M, N))
+        ar = ttgl.join(a1, a0).permute(0, 2, 1).reshape((BLOCK_M, N))
+
+        al_tmem.store(ttgl.convert_layout(al, a_layout, assert_trivial=True))
+        ar_tmem.store(ttgl.convert_layout(ar, a_layout, assert_trivial=True))
+
+        b_shared_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=32, element_bitwidth=16, rank=2)
+        b_shared = ttgl.allocate_shared_memory(ttgl.float16, [N, N], layout=b_shared_layout)
+        b_shared.store(b)
+
+        acc_tmem.store(c)
+
+        bar = ttgl.allocate_shared_memory(ttgl.int64, [1], ttgl.constexpr(mbarrier.MBarrierLayout()))
+        mbarrier.init(bar, count=1)
+
+        # This is a manually tiled MMA where the LHS is in TMEM with blockM=64;
+        # we circumvent the limitation that the LHS and the accumulator must
+        # share the same TMEM rows by storing the LHS twice.
+        #
+        # TMEM      al  ar  c
+        # [0, 16)   a0  a1  c0
+        # [16, 32)  a1  a0  c1
+        #
+        # d0 = a0 @ b00 + a1 @ b10 + c0
+        # d1 = a0 @ b01 + a1 @ b11 + c1
+
+        N2: ttgl.constexpr = N // 2
+        c0 = acc_tmem.slice(0, N2)
+        c1 = acc_tmem.slice(N2, N2)
+
+        tcgen05_mma(al_tmem.slice(0, N2), b_shared.slice(0, N2, dim=0).slice(0, N2, dim=1), c0)
+        tcgen05_mma(ar_tmem.slice(0, N2), b_shared.slice(N2, N2, dim=0).slice(0, N2, dim=1), c0)
+        tcgen05_mma(ar_tmem.slice(N2, N2), b_shared.slice(0, N2, dim=0).slice(N2, N2, dim=1), c1)
+        tcgen05_mma(al_tmem.slice(N2, N2), b_shared.slice(N2, N2, dim=0).slice(N2, N2, dim=1), c1)
+
+        tcgen05_commit(bar)
+        mbarrier.wait(bar, 0)
+        mbarrier.invalidate(bar)
+
+        d = acc_tmem.load(a_layout)
+        ttgl.store(d_ptr + a_offsets, d)
+
+    torch.manual_seed(0)
+    a = torch.randn((64, 128), dtype=torch.float16, device="cuda")
+    b = torch.randn((128, 128), dtype=torch.float16, device="cuda")
+    c = torch.randn((64, 128), dtype=torch.float32, device="cuda")
+
+    d_tri = torch.empty_like(c)
+    compiled = kernel[(1, )](a, b, c, d_tri)
+
+    ttgir = compiled.asm["ttgir"]
+    assert ttgir.count("ttng.tmem_alloc") == 3
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf32") == 1
+    assert ttgir.count("ttng.tmem_alloc : () -> !ttg.memdesc<64x128xf16") == 2
+
+    llir = compiled.asm["llir"]
+    assert llir.count("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [$1], 128")
+
+    d_ref = a @ b + c
+    torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)
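As a plain-PyTorch cross-check of the block decomposition the manual tiling in test_block_m_64_mma relies on (a standalone sketch, independent of TMEM and tcgen05; the fp32 inputs and tolerances here are arbitrary illustration choices):

    import torch

    torch.manual_seed(0)
    a = torch.randn(64, 128)
    b = torch.randn(128, 128)
    c = torch.randn(64, 128)

    # Block splits matching the kernel's comments: A = [a0 | a1], C = [c0 | c1],
    # B = [[b00, b01], [b10, b11]].
    a0, a1 = a[:, :64], a[:, 64:]
    b00, b01 = b[:64, :64], b[:64, 64:]
    b10, b11 = b[64:, :64], b[64:, 64:]
    c0, c1 = c[:, :64], c[:, 64:]

    # The two output halves that the four MMAs accumulate:
    d0 = a0 @ b00 + a1 @ b10 + c0
    d1 = a0 @ b01 + a1 @ b11 + c1
    d = torch.cat([d0, d1], dim=1)

    # Block-wise accumulation reproduces the full GEMM plus bias.
    torch.testing.assert_close(d, a @ b + c, rtol=1e-5, atol=1e-5)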