import torch
import pytest

-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
@@ -143,3 +143,66 @@ def test_warpgroup_mma(ASYNC):
    ref = torch.matmul(a, b)

    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
+@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
+@pytest.mark.parametrize("num_warps", [4, 8])
+@pytest.mark.parametrize("cdna_version", [3, 4])
+def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+
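+    # A single-tile GEMM written in Gluon: load A and B with buffer loads, issue one MFMA, and store the result.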
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak,  #
+               stride_bk, stride_bn,  #
+               stride_cm, stride_cn, BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr,
+               BLOCK_SIZE_K: ttgl.constexpr, blocked: ttgl.constexpr, mfma_layout: ttgl.constexpr):
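+        # A and B feed the MFMA as dot operands whose parent is the MFMA layout, with k_width=8.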
+        dot_a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=8)
+        dot_b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=8)
+
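+        # Build the 2D tile offsets for A and B from arange over the blocked layout and the tensor strides.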
+        offs_am = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_bn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+
+        offs_ak = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(0, blocked))
+        offs_bk = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(1, blocked))
+        offs_a = offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak
+        offs_b = offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
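+        # Load both tiles, convert them to the dot operand layouts, and accumulate the MFMA result in float32.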
+        a = ttgl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=offs_a)
+        b = ttgl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=offs_b)
+        a1 = ttgl.convert_layout(a, layout=dot_a_layout)
+        b1 = ttgl.convert_layout(b, layout=dot_b_layout)
+        acc = ttgl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], ttgl.float32, mfma_layout)
+        c = ttgl.amd.cdna3.mfma(a1, b1, acc)
+        c = ttgl.convert_layout(c, layout=blocked)
+        c = c.to(a_ptr.dtype.element_ty)
+
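+        # Compute the output offsets and write C back with a buffer store.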
+        offs_cm = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_cn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+        offs_c = offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        ttgl.amd.cdna3.buffer_store(stored_value=c, ptr=c_ptr, offsets=offs_c)
+
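+    # Only run on CDNA3/CDNA4 hardware, and only for the MFMA version that matches the target.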
+    if not is_hip_cdna4() and not is_hip_cdna3():
+        pytest.skip("mfma requires the target to be CDNA3 or CDNA4")
+
+    if is_hip_cdna3() and cdna_version != 3:
+        pytest.skip("On a CDNA3 target, skip if the MFMA version is not 3")
+
+    if is_hip_cdna4() and cdna_version != 4:
+        pytest.skip("On a CDNA4 target, skip if the MFMA version is not 4")
+
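+    # Host-side setup: random fp16/bf16 inputs plus the blocked and MFMA layouts passed to the kernel as constexprs.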
+    elem_type = torch.float16 if in_dtype == 'float16' else torch.bfloat16
+    a = torch.randn((M, K), device='cuda', dtype=elem_type) - 0.5
+    b = torch.randn((K, N), device='cuda', dtype=elem_type) - 0.5
+    c = torch.empty((M, N), device=a.device, dtype=elem_type)
+    nonkdim: ttgl.constexpr = 32
+    blocked: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[4, 4], threads_per_warp=[4, 16],
+                                                 warps_per_cta=[num_warps, 1], order=[1, 0])
+    mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=cdna_version, instr_shape=[nonkdim, nonkdim],
+                                                         transposed=True, warps_per_cta=[num_warps, 1])
+
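+    # Launch a single program and check the result against a torch.matmul reference.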
+    kernel[1, 1](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=M,
+                 BLOCK_SIZE_N=N, BLOCK_SIZE_K=K, blocked=blocked, mfma_layout=mfma_layout, num_warps=num_warps)
+
+    ref = torch.matmul(a, b)
+    triton_output = c
+    torch.testing.assert_close(ref, triton_output)