5 | 5 | from typing import Optional |
6 | 6 | import math |
7 | 7 | import textwrap |
8 | | -import pathlib |
9 | 8 |
10 | 9 | import numpy as np |
11 | 10 | import pytest |
29 | 28 | is_cuda, |
30 | 29 | is_interpreter, |
31 | 30 | is_hopper, |
32 | | - is_hopper_or_newer, |
33 | 31 | is_hip, |
34 | 32 | is_hip_cdna, |
35 | 33 | is_hip_cdna2, |
@@ -144,199 +142,6 @@ def get_src_element_ty_size(dtype_str): |
144 | 142 | raise ValueError(f"Unknown dtype {dtype_str}") |
145 | 143 |
146 | 144 |
147 | | -class MfmaLayout: |
148 | | - |
149 | | - def __init__(self, version, warps_per_cta, tiles_per_warp, instr_shape, is_transposed): |
150 | | - self.version = version |
151 | | - self.warps_per_cta = warps_per_cta |
152 | | - self.tiles_per_warp = tiles_per_warp |
153 | | - self.instr_shape = instr_shape |
154 | | - self.is_transposed = is_transposed |
155 | | - |
156 | | - def __str__(self): |
157 | | - return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, tilesPerWarp = {self.tiles_per_warp}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>" |
158 | | - |
159 | | - |
160 | | -class WmmaLayout: |
161 | | - |
162 | | - def __init__(self, version, warps_per_cta): |
163 | | - self.version = version |
164 | | - self.warps_per_cta = warps_per_cta |
165 | | - |
166 | | - def __str__(self): |
167 | | - return f"#{GPU_DIALECT}.amd_wmma<{{version = {self.version}, warpsPerCTA = {self.warps_per_cta}}}>" |
168 | | - |
169 | | - |
170 | | -class MmaLayout: |
171 | | - |
172 | | - def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): |
173 | | - self.version = version |
174 | | - self.warps_per_cta = warps_per_cta |
175 | | - self.ctas_per_cga = ctas_per_cga |
176 | | - self.cta_split_num = cta_split_num |
177 | | - self.cta_order = cta_order |
178 | | - self.instr_shape = instr_shape |
179 | | - |
180 | | - def __str__(self): |
181 | | - return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" |
182 | | - |
183 | | - |
184 | | -class DotOperandLayout: |
185 | | - |
186 | | - def __init__(self, parent, op_idx, k_width): |
187 | | - self.parent = parent |
188 | | - self.op_idx = op_idx |
189 | | - self.k_width = k_width |
190 | | - |
191 | | - def __str__(self): |
192 | | - return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" |
193 | | - |
194 | | - |
195 | | -class SliceLayout: |
196 | | - |
197 | | - def __init__(self, dim, parent): |
198 | | - self.dim = dim |
199 | | - self.parent = parent |
200 | | - |
201 | | - def __str__(self): |
202 | | - return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>" |
203 | | - |
204 | | - |
205 | | -class BlockedLayout: |
206 | | - |
207 | | - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1], |
208 | | - cta_split_num=[1, 1], cta_order=[0, 1]): |
209 | | - self.sz_per_thread = size_per_thread |
210 | | - self.threads_per_warp = threads_per_warp |
211 | | - self.warps_per_cta = warps_per_cta |
212 | | - self.order = order |
213 | | - self.ctas_per_cga = ctas_per_cga |
214 | | - self.cta_split_num = cta_split_num |
215 | | - self.cta_order = cta_order |
216 | | - |
217 | | - def __str__(self): |
218 | | - return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" |
219 | | - |
220 | | - |
221 | | -class SwizzledSharedLayout: |
222 | | - |
223 | | - def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): |
224 | | - self.vec = vec |
225 | | - self.per_phase = per_phase |
226 | | - self.max_phase = max_phase |
227 | | - self.order = order |
228 | | - self.ctas_per_cga = ctas_per_cga |
229 | | - self.cta_split_num = cta_split_num |
230 | | - self.cta_order = cta_order |
231 | | - |
232 | | - def __str__(self): |
233 | | - return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" |
234 | | - |
235 | | - |
236 | | -class PaddedSharedLayout: |
237 | | - |
238 | | - def __init__(self, interval_padding_pairs, linear_layout_offset_bases, linear_layout_block_bases): |
239 | | - self.interval_padding_pairs = "[" + ", ".join(f"{v[0]}:{v[1]:+d}" for v in interval_padding_pairs) + "]" |
240 | | - self.offset_bases = linear_layout_offset_bases |
241 | | - self.block_bases = linear_layout_block_bases |
242 | | - |
243 | | - def __str__(self): |
244 | | - return f"#{GPU_DIALECT}.padded_shared<{self.interval_padding_pairs} {{offset={self.offset_bases}, block={self.block_bases}}}>" |
245 | | - |
246 | | - |
247 | | -class NVMMASharedLayout: |
248 | | - |
249 | | - def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order): |
250 | | - self.swizzle = swizzle |
251 | | - self.transpose = transpose |
252 | | - self.element_bit_width = element_bit_width |
253 | | - self.ctas_per_cga = ctas_per_cga |
254 | | - self.cta_split_num = cta_split_num |
255 | | - self.cta_order = cta_order |
256 | | - |
257 | | - def __str__(self): |
258 | | - transpose_str = "true" if self.transpose else "false" |
259 | | - return f"#{GPU_DIALECT}.nvmma_shared<{{swizzlingByteWidth={self.swizzle}, transposed={transpose_str}, elementBitWidth={self.element_bit_width}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" |
260 | | - |
261 | | - |
262 | | -class LinearLayout: |
263 | | - |
264 | | - def __init__(self, register, lane, warp, block): |
265 | | - self.register = register |
266 | | - self.lane = lane |
267 | | - self.warp = warp |
268 | | - self.block = block |
269 | | - |
270 | | - def __str__(self): |
271 | | - return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>" |
272 | | - |
273 | | - |
274 | | -# Python impl of LinearEncodingAttr::basesPerDim |
275 | | -def bases_per_dim(layout, dim, rank, skip_broadcast=True): |
276 | | - assert isinstance(layout, LinearLayout) |
277 | | - bases = getattr(layout, dim) |
278 | | - result = [1] * rank |
279 | | - |
280 | | - if not bases: |
281 | | - return result |
282 | | - |
283 | | - non_zero_idx = None |
284 | | - |
285 | | - for basis in bases: |
286 | | - # Find the first non-zero index in the current basis |
287 | | - idx = next((i for i, v in enumerate(basis) if v != 0), None) |
288 | | - if idx is not None: |
289 | | - non_zero_idx = idx |
290 | | - result[idx] *= 2 |
291 | | - elif not skip_broadcast: |
292 | | - # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index |
293 | | - assert non_zero_idx is not None |
294 | | - result[non_zero_idx] *= 2 |
295 | | - |
296 | | - return result |
297 | | - |
298 | | - |
299 | | -def warps_per_cta(layout, shape): |
300 | | - if isinstance(layout, LinearLayout): |
301 | | - return bases_per_dim(layout, 'warp', len(shape)) |
302 | | - elif isinstance(layout, (SliceLayout, DotOperandLayout)): |
303 | | - return warps_per_cta(layout.parent, shape) |
304 | | - else: |
305 | | - return layout.warps_per_cta |
306 | | - |
307 | | - |
308 | | -def is_layout_applicable(layout) -> bool: |
309 | | - if isinstance(layout, (BlockedLayout, SwizzledSharedLayout, LinearLayout)): |
310 | | - return True |
311 | | - elif isinstance(layout, SliceLayout): |
312 | | - return is_layout_applicable(layout.parent) |
313 | | - elif is_cuda(): |
314 | | - mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout |
315 | | - if not isinstance(mma_layout, MmaLayout): |
316 | | - return False |
317 | | - if mma_layout.version[0] >= 3 and not is_hopper_or_newer(): |
318 | | - return False |
319 | | - return True |
320 | | - elif is_hip(): |
321 | | - target_arch = triton.runtime.driver.active.get_current_target().arch |
322 | | - if isinstance(layout, PaddedSharedLayout): |
323 | | - return True |
324 | | - elif any(arch for arch in ["gfx11", "gfx12"] if arch in target_arch): |
325 | | - # RDNA 3, 4 |
326 | | - return isinstance(layout, WmmaLayout) |
327 | | - elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch): |
328 | | - # CDNA 1, 2, 3, 4 |
329 | | - return isinstance(layout, MfmaLayout) |
330 | | - else: |
331 | | - return False |
332 | | - else: |
333 | | - return True |
334 | | - |
335 | | - |
336 | | -def filter_layouts(layouts): |
337 | | - return [l for l in layouts if is_layout_applicable(l)] |
338 | | - |
339 | | - |
340 | 145 | @pytest.mark.interpreter |
341 | 146 | def test_scalar_overflow(device): |
342 | 147 |
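For context on the block removed above: helper classes such as MfmaLayout, BlockedLayout, and LinearLayout exist purely to pretty-print TTGIR encoding attributes that get spliced into the string-built kernels in this file. A minimal sketch of that pattern, assuming GPU_DIALECT == "ttg" as in the rest of the suite (illustrative, not part of the change):

GPU_DIALECT = "ttg"


class BlockedLayout:
    """Pretty-printer for #ttg.blocked encodings (same fields as the removed class)."""

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order,
                 ctas_per_cga=(1, 1), cta_split_num=(1, 1), cta_order=(0, 1)):
        self.size_per_thread = list(size_per_thread)
        self.threads_per_warp = list(threads_per_warp)
        self.warps_per_cta = list(warps_per_cta)
        self.order = list(order)
        self.ctas_per_cga = list(ctas_per_cga)
        self.cta_split_num = list(cta_split_num)
        self.cta_order = list(cta_order)

    def __str__(self):
        return (f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.size_per_thread}, "
                f"threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, "
                f"order={self.order}, CTAsPerCGA={self.ctas_per_cga}, "
                f"CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>")


# The attribute string is dropped straight into a tensor type of a textual kernel:
blocked = BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
tensor_ty = f"tensor<128x128xf16, {blocked}>"
print(tensor_ty)  # tensor<128x128xf16, #ttg.blocked<{sizePerThread=[1, 8], ...}>>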
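The removed bases_per_dim / warps_per_cta helpers mirror LinearEncodingAttr::basesPerDim: each basis vector doubles the extent of the first dimension in which it is non-zero, and zero (broadcast) bases are either skipped or attributed to the last non-broadcast dimension. A standalone transcription with a worked example (the signature takes the bases list directly rather than a layout object, purely for illustration):

def bases_per_dim(bases, rank, skip_broadcast=True):
    result = [1] * rank
    if not bases:
        return result
    non_zero_idx = None
    for basis in bases:
        idx = next((i for i, v in enumerate(basis) if v != 0), None)
        if idx is not None:
            non_zero_idx = idx
            result[idx] *= 2           # each basis doubles the dim it points at
        elif not skip_broadcast:
            assert non_zero_idx is not None
            result[non_zero_idx] *= 2  # broadcast basis repeats the last real dim
    return result


# Two warp bases, one stepping along each dim -> a 2x2 warp grid per CTA.
assert bases_per_dim([[0, 32], [16, 0]], rank=2) == [2, 2]
# A zero (broadcast) basis is ignored unless skip_broadcast=False.
assert bases_per_dim([[0, 32], [0, 0]], rank=2) == [1, 2]
assert bases_per_dim([[0, 32], [0, 0]], rank=2, skip_broadcast=False) == [1, 4]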
@@ -5722,91 +5527,6 @@ def kernel(Out): |
5722 | 5527 | assert h.asm["ptx"].count("%smid") == 1 |
5723 | 5528 |
5724 | 5529 |
5725 | | -# ----------------------- |
5726 | | -# test layout conversions |
5727 | | -# ----------------------- |
5728 | | -# TODO: backend should be tested separately |
5729 | | - |
5730 | | - |
5731 | | -@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size", |
5732 | | - [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]]) |
5733 | | -def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib.Path): |
5734 | | - num_rows_per_warp = THREADS_PER_WARP // 4 |
5735 | | - num_repeats_M = triton.cdiv(M, M_tile_size) |
5736 | | - num_repeats_N = triton.cdiv(N, N_tile_size) |
5737 | | - |
5738 | | - ir = f""" |
5739 | | - #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}> |
5740 | | - #shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}> |
5741 | | - #smem = #ttg.shared_memory |
5742 | | -
5743 | | - module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ |
5744 | | - tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{ |
5745 | | - %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> |
5746 | | - %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> |
5747 | | - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> |
5748 | | - %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> |
5749 | | - %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #blocked> |
5750 | | - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> |
5751 | | - %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> |
5752 | | - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked> |
5753 | | - %7 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> |
5754 | | - %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> |
5755 | | - %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> |
5756 | | - %ptrs = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>, tensor<{M}x{N}xi32, #blocked> |
5757 | | - %11 = tt.load %ptrs {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr<f16>, #blocked> |
5758 | | -
5759 | | - %c0_i32 = arith.constant 0 : i32 |
5760 | | -
5761 | | - %12 = ttg.local_alloc : () -> !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> |
5762 | | - %13 = ttg.memdesc_index %12[%c0_i32] : !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> |
5763 | | - ttg.local_store %11, %13 : tensor<{M}x{N}xf16, #blocked> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> |
5764 | | -
5765 | | - """ |
5766 | | - |
5767 | | - for m in range(num_repeats_M): |
5768 | | - for n in range(num_repeats_N): |
5769 | | - linear_idx = n + m * num_repeats_N |
5770 | | - m_offset = m * M_tile_size |
5771 | | - n_offset = n * N_tile_size |
5772 | | - ir += f""" |
5773 | | - %view{linear_idx} = ttg.memdesc_subslice %13[{m_offset}, {n_offset}] : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> |
5774 | | - %data{linear_idx} = ttg.local_load %view{linear_idx} : !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> |
5775 | | - %inc{linear_idx} = arith.constant dense<{linear_idx}.0> : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> |
5776 | | -
5777 | | - %res{linear_idx} = arith.addf %data{linear_idx}, %inc{linear_idx} : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> |
5778 | | - ttg.local_store %res{linear_idx}, %view{linear_idx} : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> |
5779 | | - """ |
5780 | | - |
5781 | | - ir += f""" |
5782 | | - %res = ttg.local_load %13 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> tensor<{M}x{N}xf16, #blocked> |
5783 | | - tt.store %ptrs, %res : tensor<{M}x{N}x!tt.ptr<f16>, #blocked> |
5784 | | - tt.return |
5785 | | - }} |
5786 | | - }} |
5787 | | - """ |
5788 | | - |
5789 | | - temp_file = tmp_path / "test_split_subview.ttgir" |
5790 | | - temp_file.write_text(ir) |
5791 | | - kernel = triton.compile(str(temp_file)) |
5792 | | - |
5793 | | - triton_result = torch.zeros((M, N), device=device, dtype=torch.float16) |
5794 | | - kernel[(1, 1, 1)](triton_result.data_ptr()) |
5795 | | - |
5796 | | - rows = [] |
5797 | | - for m in range(num_repeats_M): |
5798 | | - columns = [] |
5799 | | - for n in range(num_repeats_N): |
5800 | | - linear_idx = n + m * num_repeats_N |
5801 | | - tile = float(linear_idx) * torch.ones((M_tile_size, N_tile_size), device=device, dtype=torch.float16) |
5802 | | - columns.append(tile) |
5803 | | - rows.append(torch.cat(columns, dim=1)) |
5804 | | - expected_result = torch.cat(rows, dim=0) |
5805 | | - |
5806 | | - test_result = torch.equal(triton_result, expected_result) |
5807 | | - assert test_result |
5808 | | - |
5809 | | - |
5810 | 5530 | @pytest.mark.interpreter |
5811 | 5531 | def test_load_scalar_with_mask(device): |
5812 | 5532 |
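For reference, the expected output of the removed test_split_subview is a block-constant matrix: the kernel loads an all-zero tensor, and each M_tile_size x N_tile_size subview is bumped in shared memory by its linear tile index n + m * num_repeats_N. A NumPy sketch of that reference tensor (the helper name and ceil-division shorthand are illustrative, not taken from the original test):

import numpy as np


def expected_split_subview(M, N, M_tile, N_tile, dtype=np.float16):
    reps_m, reps_n = -(-M // M_tile), -(-N // N_tile)  # ceil-div, like triton.cdiv
    out = np.zeros((M, N), dtype=dtype)
    for m in range(reps_m):
        for n in range(reps_n):
            out[m * M_tile:(m + 1) * M_tile, n * N_tile:(n + 1) * N_tile] = n + m * reps_n
    return out


# One of the removed parametrizations: M=128, N=128, tiles of 64x32 -> a 2x4 tile grid.
ref = expected_split_subview(128, 128, 64, 32)
assert ref[0, 0] == 0 and ref[0, 64] == 2 and ref[64, 0] == 4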