
Commit ce47711

[GLUON][TEST] Generate correct linear layouts for testing (#8033)
Previously, passing the "linear_layout" string in the reduce test was wrong because _filter_layouts would skip the string and yield no test. This PR should also cover problems we found in triton-lang/triton#8016.
1 parent 084d620 commit ce47711
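
For context on the failure mode described in the commit message, here is a minimal, self-contained sketch of it. `_filter_layouts_sketch` and `FakeLayout` are hypothetical stand-ins, not the helper used by the test suite; the point is only that a bare string silently falls out of the filtered list instead of failing loudly.

# Self-contained illustration (not the real test-suite helper): a filter that
# keeps only layout objects will silently drop a bare string, so the
# "linear_layout" branch in the old test body was never exercised.
class FakeLayout:
    """Hypothetical stand-in for the ttgl layout classes."""
    pass


def _filter_layouts_sketch(layouts):
    # Assumed behaviour, inferred from the commit message: entries that are
    # not layout objects (e.g. the string "linear_layout") are skipped.
    return [layout for layout in layouts if isinstance(layout, FakeLayout)]


cases = _filter_layouts_sketch([FakeLayout(), "linear_layout"])
assert len(cases) == 1  # the string never turns into a test case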

File tree

1 file changed: +86 −47 lines changed


python/test/gluon/test_lowerings.py

Lines changed: 86 additions & 47 deletions
@@ -82,10 +82,34 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
     torch.testing.assert_close(z_tri, z_ref)
 
 
-@pytest.mark.parametrize("M, N", [[128, 16], [32, 128], [32, 32], [16, 16]])
-@pytest.mark.parametrize(
-    "src_layout",
-    _filter_layouts([
+def _reduce_linear_layouts():
+    if THREADS_PER_WARP == 32:
+        return [
+            ttgl.DistributedLinearLayout(
+                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
+                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],
+                warp_bases=[[32, 0], [0, 32]],
+                block_bases=[],
+                shape=[64, 64],
+            )
+        ]
+    elif THREADS_PER_WARP == 64:
+        return [
+            ttgl.DistributedLinearLayout(
+                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
+                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 64]],
+                warp_bases=[[32, 0], [0, 32]],
+                block_bases=[],
+                shape=[64, 128],
+            )
+        ]
+    else:
+        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")
+
+
+def _reduce_layouts():
+    shapes = [(128, 16), (32, 128), (32, 32), (16, 16)]
+    layouts = _filter_layouts([
         # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved
         # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
         # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
@@ -104,47 +128,50 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
         ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
                                transposed=True),
         # TODO: AMDWMMA layouts
-        # WmmaLayout(version=1, warps_per_cta=[4, 1]),
-        # WmmaLayout(version=1, warps_per_cta=[1, 4]),
         ttgl.DotOperandLayout(
-            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], #
-                                               cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), #
+            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
             operand_index=1, k_width=8),
         ttgl.DotOperandLayout(
-            parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1], #
-                                               cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), #
+            parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]),
             operand_index=0, k_width=2),
         ttgl.SliceLayout(
-            dim=0,
-            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], #
-                                               cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16,
-                                                                                                          8])), #
+            dim=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                                      cta_split_num=[1, 1, 1], cta_order=[2, 1,
+                                                                                          0], instr_shape=[1, 16, 8])),
         ttgl.SliceLayout(
             dim=1, parent=ttgl.DotOperandLayout(
-                parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], #
-                                                   cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16,
-                                                                                                              8]), #
-                operand_index=1, k_width=2)),
-        "linear_layout",
-    ]))
+                parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                                   cta_split_num=[1, 1, 1], cta_order=[2, 1, 0],
+                                                   instr_shape=[1, 16, 8]), operand_index=1, k_width=2)),
+    ])
+
+    rets = []
+    for (M, N) in shapes:
+        for layout in layouts:
+            if isinstance(layout, (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)):
+                instr_shape = layout.instr_shape
+                if M < instr_shape[0] or N < instr_shape[1]:
+                    continue
+            rets.append((M, N, layout))
+    return rets
+
+
+def _reduce_cases():
+    for layout in _reduce_linear_layouts():
+        yield (layout.shape[0], layout.shape[1], layout)
+    for M, N, layout in _reduce_layouts():
+        yield (M, N, layout)
+
+
+@pytest.mark.parametrize("M, N, src_layout", _reduce_cases())
 @pytest.mark.parametrize("axis", [0, 1])
 @pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
 @pytest.mark.parametrize("dtype_str, sanitize_overflow", [("int32", False), ("int32", True), ("float32", False),
                                                           ("float16", False)])
 @pytest.mark.parametrize("reduce_op", ["sum", "max"])
 def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, sanitize_overflow, reduce_op, device):
-    if src_layout == "linear_layout":
-        src_layout = ttgl.DistributedLinearLayout(reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], #
-                                                  lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], #
-                                                  warp_bases=[[32, 0], [0, 32]], block_bases=[], shape=[M, N])
-        if THREADS_PER_WARP != (1 << len(src_layout.lane_bases)):
-            pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane_bases)} threads per warp")
-        elif M < 64 or N < 64:
-            pytest.skip(f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={M}, N={N}")
-    if isinstance(src_layout,
-                  (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)) and (M < src_layout.instr_shape[0]
-                                                                              or N < src_layout.instr_shape[1]):
-        pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
 
     @gluon.jit
     def _add(a, b):
@@ -240,9 +267,33 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr):
 ])
 
 
-@pytest.mark.parametrize("M, bins", [[2048, 2], [8, 512], [32, 32]])
-@pytest.mark.parametrize("src_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]), "linear_layout"])
-@pytest.mark.parametrize("dst_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0])])
+def _histogram_cases():
+    if THREADS_PER_WARP not in (32, 64):
+        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")
+
+    m_bins = [(2048, 2), (8, 512), (32, 32)]
+    layouts = [(ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4],
+                                   [0]), ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))]
+    for m, bins in m_bins:
+        for src_layout, dst_layout in layouts:
+            yield (m, bins, src_layout, dst_layout)
+    import math
+
+    linear_layouts = [(
+        ttgl.DistributedLinearLayout(
+            reg_bases=[[1 << (5 + i)] for i in range(int(math.log2(m)) - 5)],
+            lane_bases=[[0], [16], [4], [2], [1]] + ([[0]] if THREADS_PER_WARP == 64 else []),
+            warp_bases=[[0], [8]],
+            block_bases=[],
+            shape=(m, ),
+        ),
+        bins,
+    ) for (m, bins) in m_bins if m >= 32]
+    for linear_layout, bins in linear_layouts:
+        yield (linear_layout.shape[0], bins, linear_layout, ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))
+
+
+@pytest.mark.parametrize("M, bins, src_layout, dst_layout", _histogram_cases())
 def test_histogram(M, bins, src_layout, dst_layout, device):
 
     @gluon.jit
@@ -254,18 +305,6 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, B: ttgl.constexpr, src_layout: ttgl.
         z_offs = ttgl.arange(0, B, layout=dst_layout)
         ttgl.store(z_ptr + z_offs, h)
 
-    if src_layout == "linear_layout":
-        if M == 32:
-            src_layout = ttgl.DistributedLinearLayout(
-                reg_bases=[],
-                lane_bases=[[0], [16], [4], [2], [1]] + [[0]] * (THREADS_PER_WARP >> 6),
-                warp_bases=[[0], [8]],
-                block_bases=[],
-                shape=(M, ),
-            )
-        else:
-            pytest.skip("Linear layout is specialized for 32 elements")
-
     torch.manual_seed(0)
     x = torch.randint(0, bins, (M, ), dtype=torch.int32, device=device)
     z = torch.zeros((bins, ), dtype=torch.int32, device=device)
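As a closing cross-check on the basis lists added in this commit, the standalone sketch below recomputes the shape spanned by a set of linear-layout bases. It assumes the usual linear-layout convention that every basis is a power-of-two offset per dimension and that the footprint is the XOR-span of all bases (zero bases broadcast); spanned_shape is an illustrative helper, not part of the PR or of Triton.

# Standalone sanity check (not part of the PR): recompute the shape spanned by
# a set of linear-layout bases, assuming each basis is a power-of-two offset
# per dimension and the footprint is the XOR-span of all bases.
def spanned_shape(reg_bases, lane_bases, warp_bases, block_bases):
    bases = reg_bases + lane_bases + warp_bases + block_bases
    covered = [0] * len(bases[0])
    for basis in bases:
        for dim, offset in enumerate(basis):
            covered[dim] |= offset  # zero bases (broadcasting) add nothing
    return [c + 1 for c in covered]


# 32-thread reduce layout from this commit spans the declared 64x64 tile.
assert spanned_shape(
    reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
    lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],
    warp_bases=[[32, 0], [0, 32]],
    block_bases=[],
) == [64, 64]

# 32-thread histogram layout for M == 2048: register bases cover offsets
# 32..1024, lanes cover 1, 2, 4 and 16, warps cover 8, i.e. all 2048 elements.
assert spanned_shape(
    reg_bases=[[1 << (5 + i)] for i in range(6)],
    lane_bases=[[0], [16], [4], [2], [1]],
    warp_bases=[[0], [8]],
    block_bases=[],
) == [2048]

The 64-thread variants work out the same way: the extra [0, 64] lane basis extends the second dimension of the reduce layout to 128, and the extra zero lane basis in the histogram layout leaves its span at M.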