
Commit 0ef5eae

[GLUON][TEST] Finish subslice test and remove all layout helpers from Triton's test_core.py (triton-lang#8049)
1 parent 0ce5d77 commit 0ef5eae

2 files changed: +45 -280 lines changed


python/test/gluon/test_lowerings.py

Lines changed: 45 additions & 0 deletions
@@ -1209,3 +1209,48 @@ def test_gather_layouts(axis, src_layout, index_layout, src_shape, idx_shape, de

     torch.testing.assert_close(out, ref, rtol=0, atol=0)
     assert ("nvvm.shfl.sync.idx" in obj.asm["llir"]) or ("llvm.amdgcn.ds.bpermute" in obj.asm["llir"])
+
+
+@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size",
+                         [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]])
+def test_memdesc_subslice(M, N, M_tile_size, N_tile_size, device):
+    if M % M_tile_size != 0 or N % N_tile_size != 0:
+        pytest.skip(f"Shape size ({M}, {N}) must be divisible by tile size ({M_tile_size}, {N_tile_size})")
+
+    num_rows_per_warp = THREADS_PER_WARP // 4
+    blocked_layout = ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[num_rows_per_warp, 4],
+                                        warps_per_cta=[4, 1], order=[1, 0])
+    shared_layout = ttgl.SwizzledSharedLayout(vec=8, per_phase=1, max_phase=8, order=[1, 0])
+
+    @gluon.jit
+    def kernel(
+            out,
+            M: ttgl.constexpr,
+            N: ttgl.constexpr,
+            BLOCK_SIZE_M: ttgl.constexpr,
+            BLOCK_SIZE_N: ttgl.constexpr,
+            blocked_layout: ttgl.constexpr,
+            shared_layout: ttgl.constexpr,
+    ):
+        offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+        offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+        vals = ttgl.load(out + offs_m * N + offs_n)
+
+        smem: ttgl.shared_memory_descriptor = ttgl.allocate_shared_memory(vals.dtype, (M, N), shared_layout, value=vals)
+        for i in ttgl.static_range(M // BLOCK_SIZE_M):
+            for j in ttgl.static_range(N // BLOCK_SIZE_N):
+                tile = smem.slice(i * BLOCK_SIZE_M, BLOCK_SIZE_M, dim=0).slice(j * BLOCK_SIZE_N, BLOCK_SIZE_N, dim=1)
+                tile_vals = tile.load(blocked_layout)
+                tile_offs_m = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked_layout))[:, None]
+                tile_offs_n = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked_layout))[None, :]
+                linear_idx = tile_offs_m * N + tile_offs_n + i * BLOCK_SIZE_M * N + j * BLOCK_SIZE_N
+                tile.store(linear_idx + tile_vals)
+
+        vals = smem.load(blocked_layout)
+        ttgl.store(out + offs_m * N + offs_n, vals)
+
+    out = torch.zeros((M, N), device=device, dtype=torch.float16)
+    kernel[(1, )](out, M, N, M_tile_size, N_tile_size, blocked_layout, shared_layout)
+
+    out_ref = torch.arange(0, M * N, device=device).reshape((M, N)).to(torch.float16)
+    torch.testing.assert_close(out, out_ref, rtol=0, atol=0)
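
For readers skimming the diff, the expected result of the new test can be sanity-checked on the host: the output buffer starts as zeros, and each BLOCK_SIZE_M x BLOCK_SIZE_N tile of shared memory is overwritten with its elements' global row-major indices, so the final buffer should equal torch.arange(M * N) reshaped to (M, N) and cast to float16. The sketch below is a rough pure-PyTorch model of that tiling arithmetic for one of the parametrized shapes; it is illustrative only and not part of the change.

    import torch

    # Host-side model of the kernel's tile loop (one parametrized case).
    M, N, BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 128, 64, 32
    out = torch.zeros((M, N), dtype=torch.float16)

    for i in range(M // BLOCK_SIZE_M):
        for j in range(N // BLOCK_SIZE_N):
            tile_offs_m = torch.arange(BLOCK_SIZE_M)[:, None]
            tile_offs_n = torch.arange(BLOCK_SIZE_N)[None, :]
            linear_idx = tile_offs_m * N + tile_offs_n + i * BLOCK_SIZE_M * N + j * BLOCK_SIZE_N
            tile = out[i * BLOCK_SIZE_M:(i + 1) * BLOCK_SIZE_M,
                       j * BLOCK_SIZE_N:(j + 1) * BLOCK_SIZE_N]
            # Mirrors tile.store(linear_idx + tile_vals), where tile_vals is all zeros.
            tile.copy_(linear_idx + tile)

    out_ref = torch.arange(0, M * N).reshape((M, N)).to(torch.float16)
    torch.testing.assert_close(out, out_ref, rtol=0, atol=0)

Assuming a checkout with a supported GPU, the new test would presumably be run with something like: python -m pytest python/test/gluon/test_lowerings.py -k test_memdesc_subslice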

python/test/unit/language/test_core.py

Lines changed: 0 additions & 280 deletions
@@ -5,7 +5,6 @@
 from typing import Optional
 import math
 import textwrap
-import pathlib

 import numpy as np
 import pytest
@@ -29,7 +28,6 @@
     is_cuda,
     is_interpreter,
     is_hopper,
-    is_hopper_or_newer,
     is_hip,
     is_hip_cdna,
     is_hip_cdna2,
@@ -144,199 +142,6 @@ def get_src_element_ty_size(dtype_str):
     raise ValueError(f"Unknown dtype {dtype_str}")


-class MfmaLayout:
-
-    def __init__(self, version, warps_per_cta, tiles_per_warp, instr_shape, is_transposed):
-        self.version = version
-        self.warps_per_cta = warps_per_cta
-        self.tiles_per_warp = tiles_per_warp
-        self.instr_shape = instr_shape
-        self.is_transposed = is_transposed
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.amd_mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, tilesPerWarp = {self.tiles_per_warp}, instrShape={self.instr_shape}, isTransposed = {str(self.is_transposed).lower()}}}>"
-
-
-class WmmaLayout:
-
-    def __init__(self, version, warps_per_cta):
-        self.version = version
-        self.warps_per_cta = warps_per_cta
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.amd_wmma<{{version = {self.version}, warpsPerCTA = {self.warps_per_cta}}}>"
-
-
-class MmaLayout:
-
-    def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape):
-        self.version = version
-        self.warps_per_cta = warps_per_cta
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
-        self.instr_shape = instr_shape
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
-
-
-class DotOperandLayout:
-
-    def __init__(self, parent, op_idx, k_width):
-        self.parent = parent
-        self.op_idx = op_idx
-        self.k_width = k_width
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>"
-
-
-class SliceLayout:
-
-    def __init__(self, dim, parent):
-        self.dim = dim
-        self.parent = parent
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.slice<{{dim = {self.dim}, parent = {self.parent}}}>"
-
-
-class BlockedLayout:
-
-    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
-                 cta_split_num=[1, 1], cta_order=[0, 1]):
-        self.sz_per_thread = size_per_thread
-        self.threads_per_warp = threads_per_warp
-        self.warps_per_cta = warps_per_cta
-        self.order = order
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
-
-
-class SwizzledSharedLayout:
-
-    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
-        self.vec = vec
-        self.per_phase = per_phase
-        self.max_phase = max_phase
-        self.order = order
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
-
-
-class PaddedSharedLayout:
-
-    def __init__(self, interval_padding_pairs, linear_layout_offset_bases, linear_layout_block_bases):
-        self.interval_padding_pairs = "[" + ", ".join(f"{v[0]}:{v[1]:+d}" for v in interval_padding_pairs) + "]"
-        self.offset_bases = linear_layout_offset_bases
-        self.block_bases = linear_layout_block_bases
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.padded_shared<{self.interval_padding_pairs} {{offset={self.offset_bases}, block={self.block_bases}}}>"
-
-
-class NVMMASharedLayout:
-
-    def __init__(self, swizzle, transpose, element_bit_width, ctas_per_cga, cta_split_num, cta_order):
-        self.swizzle = swizzle
-        self.transpose = transpose
-        self.element_bit_width = element_bit_width
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
-
-    def __str__(self):
-        transpose_str = "true" if self.transpose else "false"
-        return f"#{GPU_DIALECT}.nvmma_shared<{{swizzlingByteWidth={self.swizzle}, transposed={transpose_str}, elementBitWidth={self.element_bit_width}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
-
-
-class LinearLayout:
-
-    def __init__(self, register, lane, warp, block):
-        self.register = register
-        self.lane = lane
-        self.warp = warp
-        self.block = block
-
-    def __str__(self):
-        return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>"
-
-
-# Python impl of LinearEncodingAttr::basesPerDim
-def bases_per_dim(layout, dim, rank, skip_broadcast=True):
-    assert isinstance(layout, LinearLayout)
-    bases = getattr(layout, dim)
-    result = [1] * rank
-
-    if not bases:
-        return result
-
-    non_zero_idx = None
-
-    for basis in bases:
-        # Find the first non-zero index in the current basis
-        idx = next((i for i, v in enumerate(basis) if v != 0), None)
-        if idx is not None:
-            non_zero_idx = idx
-            result[idx] *= 2
-        elif not skip_broadcast:
-            # If no non-zero found and we're not skipping broadcasts, use the last found non-zero index
-            assert non_zero_idx is not None
-            result[non_zero_idx] *= 2
-
-    return result
-
-
-def warps_per_cta(layout, shape):
-    if isinstance(layout, LinearLayout):
-        return bases_per_dim(layout, 'warp', len(shape))
-    elif isinstance(layout, (SliceLayout, DotOperandLayout)):
-        return warps_per_cta(layout.parent, shape)
-    else:
-        return layout.warps_per_cta
-
-
-def is_layout_applicable(layout) -> bool:
-    if isinstance(layout, (BlockedLayout, SwizzledSharedLayout, LinearLayout)):
-        return True
-    elif isinstance(layout, SliceLayout):
-        return is_layout_applicable(layout.parent)
-    elif is_cuda():
-        mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
-        if not isinstance(mma_layout, MmaLayout):
-            return False
-        if mma_layout.version[0] >= 3 and not is_hopper_or_newer():
-            return False
-        return True
-    elif is_hip():
-        target_arch = triton.runtime.driver.active.get_current_target().arch
-        if isinstance(layout, PaddedSharedLayout):
-            return True
-        elif any(arch for arch in ["gfx11", "gfx12"] if arch in target_arch):
-            # RDNA 3, 4
-            return isinstance(layout, WmmaLayout)
-        elif any(arch for arch in ["gfx8", "gfx9"] if arch in target_arch):
-            # CDNA 1, 2, 3, 4
-            return isinstance(layout, MfmaLayout)
-        else:
-            return False
-    else:
-        return True
-
-
-def filter_layouts(layouts):
-    return [l for l in layouts if is_layout_applicable(l)]
-
-
 @pytest.mark.interpreter
 def test_scalar_overflow(device):

@@ -5722,91 +5527,6 @@ def kernel(Out):
     assert h.asm["ptx"].count("%smid") == 1


-# -----------------------
-# test layout conversions
-# -----------------------
-# TODO: backend should be tested separately
-
-
-@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size",
-                         [[128, 128, 64, 64], [128, 128, 64, 32], [128, 64, 64, 32], [256, 128, 64, 64]])
-def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib.Path):
-    num_rows_per_warp = THREADS_PER_WARP // 4
-    num_repeats_M = triton.cdiv(M, M_tile_size)
-    num_repeats_N = triton.cdiv(N, N_tile_size)
-
-    ir = f"""
-    #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}>
-    #shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}>
-    #smem = #ttg.shared_memory
-
-    module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-    tt.func public @kernel(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
-        %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
-        %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>
-        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
-        %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
-        %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #blocked>
-        %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
-        %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked>
-        %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
-        %7 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
-        %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
-        %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked>
-        %ptrs = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>, tensor<{M}x{N}xi32, #blocked>
-        %11 = tt.load %ptrs {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>
-
-        %c0_i32 = arith.constant 0 : i32
-
-        %12 = ttg.local_alloc : () -> !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable>
-        %13 = ttg.memdesc_index %12[%c0_i32] : !ttg.memdesc<1x{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable>
-        ttg.local_store %11, %13 : tensor<{M}x{N}xf16, #blocked> -> !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable>
-
-    """
-
-    for m in range(num_repeats_M):
-        for n in range(num_repeats_N):
-            linear_idx = n + m * num_repeats_N
-            m_offset = m * M_tile_size
-            n_offset = n * N_tile_size
-            ir += f"""
-        %view{linear_idx} = ttg.memdesc_subslice %13[{m_offset}, {n_offset}] : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}>
-        %data{linear_idx} = ttg.local_load %view{linear_idx} : !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
-        %inc{linear_idx} = arith.constant dense<{linear_idx}.0> : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
-
-        %res{linear_idx} = arith.addf %data{linear_idx}, %inc{linear_idx} : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
-        ttg.local_store %res{linear_idx}, %view{linear_idx} : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> -> !ttg.memdesc<{M_tile_size}x{N_tile_size}xf16, #shared, #smem, mutable, {M}x{N}>
-        """
-
-    ir += f"""
-        %res = ttg.local_load %13 : !ttg.memdesc<{M}x{N}xf16, #shared, #smem, mutable> -> tensor<{M}x{N}xf16, #blocked>
-        tt.store %ptrs, %res : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>
-        tt.return
-    }}
-    }}
-    """
-
-    temp_file = tmp_path / "test_split_subview.ttgir"
-    temp_file.write_text(ir)
-    kernel = triton.compile(str(temp_file))
-
-    triton_result = torch.zeros((M, N), device=device, dtype=torch.float16)
-    kernel[(1, 1, 1)](triton_result.data_ptr())
-
-    rows = []
-    for m in range(num_repeats_M):
-        columns = []
-        for n in range(num_repeats_N):
-            linear_idx = n + m * num_repeats_N
-            tile = float(linear_idx) * torch.ones((M_tile_size, N_tile_size), device=device, dtype=torch.float16)
-            columns.append(tile)
-        rows.append(torch.cat(columns, dim=1))
-    expected_result = torch.cat(rows, dim=0)
-
-    test_result = torch.equal(triton_result, expected_result)
-    assert test_result
-
-
 @pytest.mark.interpreter
 def test_load_scalar_with_mask(device):
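
The removed TTGIR test and the new Gluon test exercise the same shared-memory subslicing path; the rough correspondence, as far as the two tests show, is sketched below. The helper name tile_view is hypothetical and only groups the two forms side by side.

    def tile_view(smem, m_offset, n_offset, M_tile_size, N_tile_size):
        # Gluon form used in test_memdesc_subslice: chain slice(offset, length, dim)
        # once per dimension on the shared-memory descriptor.
        #
        # TTGIR form used in the removed test_split_subview: emit the subview op
        # directly in the IR string, with explicit element offsets, e.g.
        #   %view = ttg.memdesc_subslice %13[m_offset, n_offset]
        #       : !ttg.memdesc<MxNxf16, #shared, #smem, mutable>
        #       -> !ttg.memdesc<M_tilexN_tilexf16, #shared, #smem, mutable, MxN>
        return smem.slice(m_offset, M_tile_size, dim=0).slice(n_offset, N_tile_size, dim=1)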
