4 | 4 | import triton |
5 | 5 | from triton.experimental import gluon |
6 | 6 | from triton.experimental.gluon import language as ttgl |
| 7 | +from triton._internal_testing import is_cuda, is_hip, is_hopper_or_newer |
| 8 | + |
| 9 | + |
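| | +# Keep only layouts that the active backend can actually lower; the parametrizations below are wrapped in _filter_layouts. |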
| 10 | +def _is_layout_applicable(layout) -> bool: |
| | +    # Blocked layouts and the "linear_layout" placeholder are backend-agnostic, so they always apply. |
| | +    if isinstance(layout, ttgl.BlockedLayout) or layout == "linear_layout": |
| | +        return True |
| 11 | +    if isinstance(layout, ttgl.SliceLayout): |
| 12 | + return _is_layout_applicable(layout.parent) |
| 13 | + elif is_cuda(): |
| 14 | + mma_layout = layout.parent if isinstance(layout, ttgl.DotOperandLayout) else layout |
| 15 | + if not isinstance(mma_layout, ttgl.NVMMADistributedLayout): |
| 16 | + return False |
| 17 | + if mma_layout.version[0] >= 3 and not is_hopper_or_newer(): |
| 18 | + return False |
| 19 | + return True |
| 20 | + elif is_hip(): |
| 21 | +        # TODO: Add other AMD layouts |
| 22 | + return isinstance(layout, ttgl.amd.AMDMFMALayout) |
| 23 | + else: |
| 24 | + return True |
| 25 | + |
| 26 | + |
| 27 | +def _filter_layouts(layouts): |
| 28 | + return [l for l in layouts if _is_layout_applicable(l)] |
| 29 | + |
7 | 30 |
8 | 31 | THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size |
9 | 32 |
10 | 33 |
11 | 34 | @pytest.mark.parametrize("M, N", [(32, 16), (32, 32), (32, 64), (64, 32)]) |
12 | | -@pytest.mark.parametrize("src_layout", [ |
13 | | - ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]), |
14 | | - ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]), |
15 | | - ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]), |
16 | | - ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]), |
17 | | - ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]), |
18 | | - ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]), |
19 | | - ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), |
20 | | - ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]), |
21 | | - ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]), |
22 | | - ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]), |
23 | | - ttgl.BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0]), |
24 | | -]) |
| 35 | +@pytest.mark.parametrize( |
| 36 | + "src_layout", |
| 37 | + _filter_layouts([ |
| 38 | + ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]), |
| 39 | + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]), |
| 40 | + ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]), |
| 41 | + ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]), |
| 42 | + ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]), |
| 43 | + ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]), |
| 44 | + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), |
| 45 | + ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]), |
| 46 | + ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]), |
| 47 | + ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]), |
| 48 | + ttgl.BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0]), |
| 49 | + ])) |
25 | 50 | @pytest.mark.parametrize("axis", [0, 1]) |
26 | 51 | @pytest.mark.parametrize("sanitize_overflow", [False, True]) |
27 | 52 | def test_scan_layouts(M, N, src_layout, axis, sanitize_overflow, device): |
@@ -49,3 +74,117 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons |
49 | 74 |
50 | 75 | z_ref = torch.cumsum(x, dim=axis, dtype=torch.int32) |
51 | 76 | torch.testing.assert_close(z_tri, z_ref) |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize("M, N", [[128, 16], [32, 128], [32, 32], [16, 16]]) |
| 80 | +@pytest.mark.parametrize( |
| 81 | + "src_layout", |
| 82 | + _filter_layouts([ |
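| | +        # Layout coverage: blocked, NVMMA (v2/v3), AMD MFMA, dot-operand, sliced, and a linear layout; |
| | +        # _filter_layouts drops the ones the active backend cannot lower. |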
| 83 | +        # FIXME: Do not enable these tests until the SLPVectorizer problem with the nvptx target has been resolved |
| 84 | + # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])), |
| 85 | + # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])), |
| 86 | + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), |
| 87 | + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), |
| 88 | + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], |
| 89 | + cta_order=[0, 1], instr_shape=[16, 8]), |
| 90 | + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], |
| 91 | + cta_order=[1, 0], instr_shape=[16, 16, 16]), |
| 92 | + ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32], |
| 93 | + transposed=False), |
| 94 | + ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32], |
| 95 | + transposed=False), |
| 96 | + ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32], |
| 97 | + transposed=True), |
| 98 | + ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32], |
| 99 | + transposed=True), |
| 100 | + # TODO: AMDWMMA layouts |
| 101 | + # WmmaLayout(version=1, warps_per_cta=[4, 1]), |
| 102 | + # WmmaLayout(version=1, warps_per_cta=[1, 4]), |
| 103 | + ttgl.DotOperandLayout( |
| 104 | + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], # |
| 105 | + cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), # |
| 106 | + operand_index=1, k_width=8), |
| 107 | + ttgl.DotOperandLayout( |
| 108 | + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1], # |
| 109 | + cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), # |
| 110 | + operand_index=0, k_width=2), |
| 111 | + ttgl.SliceLayout( |
| 112 | + dim=0, |
| 113 | + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], # |
| 114 | + cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16, |
| 115 | + 8])), # |
| 116 | + ttgl.SliceLayout( |
| 117 | + dim=1, parent=ttgl.DotOperandLayout( |
| 118 | + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], # |
| 119 | + cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16, |
| 120 | + 8]), # |
| 121 | + operand_index=1, k_width=2)), |
| 122 | + "linear_layout", |
| 123 | + ])) |
| 124 | +@pytest.mark.parametrize("axis", [0, 1]) |
| 125 | +@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d']) |
| 126 | +@pytest.mark.parametrize("dtype_str, sanitize_overflow", [("int32", False), ("int32", True), ("float32", False), |
| 127 | + ("float16", False)]) |
| 128 | +@pytest.mark.parametrize("reduce_op", ["sum", "max"]) |
| 129 | +def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, sanitize_overflow, reduce_op, device): |
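| | +    # "linear_layout" is a placeholder; the concrete DistributedLinearLayout depends on M, N and the warp size. |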
| 130 | + if src_layout == "linear_layout": |
| 131 | + src_layout = ttgl.DistributedLinearLayout(reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], # |
| 132 | + lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], # |
| 133 | + warp_bases=[[32, 0], [0, 32]], block_bases=[], shape=[M, N]) |
| 134 | + if THREADS_PER_WARP != (1 << len(src_layout.lane_bases)): |
| 135 | + pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane_bases)} threads per warp") |
| 136 | + elif M < 64 or N < 64: |
| 137 | + pytest.skip(f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={M}, N={N}") |
| 138 | + if isinstance(src_layout, |
| 139 | + (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)) and (M < src_layout.instr_shape[0] |
| 140 | + or N < src_layout.instr_shape[1]): |
| 141 | +        pytest.skip("Skipping because the tensor shape is smaller than the MMA/MFMA layout's instr_shape") |
| 142 | + |
| 143 | + @gluon.jit |
| 144 | + def _add(a, b): |
| 145 | + return a + b |
| 146 | + |
| 147 | + @gluon.jit |
| 148 | + def _max(a, b): |
| 149 | + return ttgl.maximum(a, b) |
| 150 | + |
| 151 | + combine_fn = _add if reduce_op == "sum" else _max |
| 152 | + |
| 153 | + @gluon.jit |
| 154 | + def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.constexpr, axis: ttgl.constexpr, |
| 155 | + epilogue_kind: ttgl.constexpr): |
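| | +        # 2-D offsets are built from sliced 1-D ranges so the loaded tensor carries the layout under test. |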
| 156 | + x_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, layout))[:, None] |
| 157 | + x_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, layout))[None, :] |
| 158 | + x = ttgl.load(x_ptr + x_offs_m * N + x_offs_n) |
| 159 | + y = ttgl.reduce(x, axis=axis, combine_fn=combine_fn) |
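| | +        # Epilogue variants: "reduce1d" stores the 1-D result, "reduce2d" reduces it again to a scalar, |
| | +        # and "expand_reduce2d" re-expands the reduced dim before reducing along the other axis. |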
| 160 | + if epilogue_kind == "reduce1d": |
| 161 | + if axis == 0: |
| 162 | + z_offs = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, layout)) |
| 163 | + else: |
| 164 | + z_offs = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, layout)) |
| 165 | + ttgl.store(z_ptr + z_offs, y) |
| 166 | + elif epilogue_kind == "reduce2d": |
| 167 | + y = ttgl.reduce(y, axis=0, combine_fn=combine_fn) |
| 168 | + ttgl.store(z_ptr, y) |
| 169 | + elif epilogue_kind == "expand_reduce2d": |
| 170 | + y = ttgl.expand_dims(y, axis=axis) |
| 171 | + y = ttgl.reduce(y, axis=1 - axis, combine_fn=combine_fn) |
| 172 | + z_offs = ttgl.arange(0, 1, layout=ttgl.SliceLayout(1 - axis, layout)) |
| 173 | + ttgl.store(z_ptr + z_offs, y) |
| 174 | + |
| 175 | + torch.manual_seed(0) |
| 176 | + |
| 177 | + torch_dtype = getattr(torch, dtype_str) |
| 178 | + x = torch.randint(-10, 10, (M, N), dtype=torch.int32, device=device).to(torch_dtype) |
| 179 | + out_shape = (1, 1) if "reduce2d" in epilogue_kind else (1, N) if axis == 0 else (M, 1) |
| 180 | + z = torch.empty(out_shape, dtype=torch_dtype, device=device) |
| 181 | + |
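| | +    # Launch with exactly the number of warps the layout expects (product of its warps per CTA). |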
| 182 | + num_warps = int(torch.prod(torch.tensor(ttgl._layouts.warps_per_cta(src_layout, (M, N))))) |
| 183 | + kernel[(1, 1, 1)](x, z, M, N, src_layout, axis, num_warps=num_warps, epilogue_kind=epilogue_kind, |
| 184 | + sanitize_overflow=sanitize_overflow, debug=sanitize_overflow) |
| 185 | + |
| 186 | + reduce_fn = torch.sum if reduce_op == "sum" else torch.amax |
| 187 | + z_ref = reduce_fn(x, dim=axis, keepdim=True) |
| 188 | + if epilogue_kind in ("expand_reduce2d", "reduce2d"): |
| 189 | + z_ref = reduce_fn(z_ref, dim=1 - axis, keepdim=True) |
| 190 | + torch.testing.assert_close(z, z_ref.to(torch_dtype)) |