Skip to content

Commit ca34989

Browse files
agron911 and meta-codesync[bot]
authored and committed
[Cherry-pick][RESOLVED] [GLUON][TEST] Generate correct linear layouts for testing (#8033) (#633)
Summary: ⚠️ **MERGE CONFLICTS DETECTED** ⚠️ This cherry-pick contains merge conflicts that require manual resolution. Original Commit: ce47711 Original Author: Keren Zhou Original Date: 2025-09-02 15:22:35 -0400 **Action Required:** 1. Check out this branch locally 2. Resolve the merge conflicts in the affected files 3. Commit the resolved changes 4. Update this PR Original commit message: ``` [GLUON][TEST] Generate correct linear layouts for testing (#8033) Previously passing the "linear_layout" string in the reduce test is wrong because _filter_layouts will skip the string and yield no test. This PR should also cover problems we found in triton-lang/triton#8016 ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. The conflicts have been committed with conflict markers for easier resolution. Pull Request resolved: #633 Reviewed By: dshi7 Differential Revision: D86218450 Pulled By: agron911 fbshipit-source-id: 291933dfbbb63791a8746ac8b738ce51706c402d
1 parent 2af19bc commit ca34989

File tree

1 file changed

+89
-86
lines changed

1 file changed

+89
-86
lines changed

python/test/gluon/test_lowerings.py

Lines changed: 89 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,34 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
8383
torch.testing.assert_close(z_tri, z_ref)
8484

8585

86-
@pytest.mark.parametrize("M, N", [[128, 16], [32, 128], [32, 32], [16, 16]])
87-
@pytest.mark.parametrize(
88-
"src_layout",
89-
_filter_layouts([
86+
def _reduce_linear_layouts():
    """Return DistributedLinearLayout test cases matching the current warp width.

    The lane bases must supply exactly log2(THREADS_PER_WARP) entries, so a
    distinct layout (and tensor shape) is needed for 32- vs 64-thread warps.

    Raises:
        RuntimeError: if THREADS_PER_WARP is neither 32 nor 64.
    """
    if THREADS_PER_WARP == 32:
        return [
            ttgl.DistributedLinearLayout(
                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],
                warp_bases=[[32, 0], [0, 32]],
                block_bases=[],
                shape=[64, 64],
            )
        ]
    elif THREADS_PER_WARP == 64:
        return [
            ttgl.DistributedLinearLayout(
                # One extra lane basis ([0, 64]) for the 6th lane bit; the
                # second dimension grows to 128 accordingly.
                reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
                lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 64]],
                warp_bases=[[32, 0], [0, 32]],
                block_bases=[],
                shape=[64, 128],
            )
        ]
    else:
        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")
109+
110+
111+
def _reduce_layouts():
112+
shapes = [(128, 16), (32, 128), (32, 32), (16, 16)]
113+
layouts = _filter_layouts([
90114
# FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved
91115
# SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
92116
# SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
@@ -117,83 +141,50 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
117141
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
118142
transposed=True),
119143
# TODO: AMDWMMA layouts
120-
# WmmaLayout(version=1, warps_per_cta=[4, 1]),
121-
# WmmaLayout(version=1, warps_per_cta=[1, 4]),
122144
ttgl.DotOperandLayout(
123-
parent=ttgl.NVMMADistributedLayout(
124-
version=[2, 0],
125-
warps_per_cta=[2, 4],
126-
ctas_per_cga=[1, 1], #
127-
cta_split_num=[1, 1],
128-
cta_order=[0, 1],
129-
instr_shape=[16, 8],
130-
), #
131-
operand_index=1,
132-
k_width=8,
133-
),
145+
parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
146+
cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
147+
operand_index=1, k_width=8),
134148
ttgl.DotOperandLayout(
135-
parent=ttgl.NVMMADistributedLayout(
136-
version=[3, 0],
137-
warps_per_cta=[8, 1],
138-
ctas_per_cga=[1, 1], #
139-
cta_split_num=[1, 1],
140-
cta_order=[1, 0],
141-
instr_shape=[16, 32, 16],
142-
), #
143-
operand_index=0,
144-
k_width=2,
145-
),
149+
parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1],
150+
cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]),
151+
operand_index=0, k_width=2),
146152
ttgl.SliceLayout(
147-
dim=0,
148-
parent=ttgl.NVMMADistributedLayout(
149-
version=[2, 0],
150-
warps_per_cta=[4, 1, 1],
151-
ctas_per_cga=[1, 1, 1], #
152-
cta_split_num=[1, 1, 1],
153-
cta_order=[2, 1, 0],
154-
instr_shape=[1, 16, 8],
155-
),
156-
), #
153+
dim=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
154+
cta_split_num=[1, 1, 1], cta_order=[2, 1,
155+
0], instr_shape=[1, 16, 8])),
157156
ttgl.SliceLayout(
158-
dim=1,
159-
parent=ttgl.DotOperandLayout(
160-
parent=ttgl.NVMMADistributedLayout(
161-
version=[2, 0],
162-
warps_per_cta=[4, 1, 1],
163-
ctas_per_cga=[1, 1, 1], #
164-
cta_split_num=[1, 1, 1],
165-
cta_order=[2, 1, 0],
166-
instr_shape=[1, 16, 8],
167-
), #
168-
operand_index=1,
169-
k_width=2,
170-
),
171-
),
172-
"linear_layout",
173-
]),
174-
)
157+
dim=1, parent=ttgl.DotOperandLayout(
158+
parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
159+
cta_split_num=[1, 1, 1], cta_order=[2, 1, 0],
160+
instr_shape=[1, 16, 8]), operand_index=1, k_width=2)),
161+
])
162+
163+
rets = []
164+
for (M, N) in shapes:
165+
for layout in layouts:
166+
if isinstance(layout, (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)):
167+
instr_shape = layout.instr_shape
168+
if M < instr_shape[0] or N < instr_shape[1]:
169+
continue
170+
rets.append((M, N, layout))
171+
return rets
172+
173+
174+
def _reduce_cases():
    """Yield (M, N, layout) parameter tuples for test_reduce_layouts.

    Combines the linear-layout cases (whose M, N come from the layout's own
    shape) with the shape x layout product from _reduce_layouts().
    """
    for layout in _reduce_linear_layouts():
        yield (layout.shape[0], layout.shape[1], layout)
    for M, N, layout in _reduce_layouts():
        yield (M, N, layout)
179+
180+
181+
@pytest.mark.parametrize("M, N, src_layout", _reduce_cases())
175182
@pytest.mark.parametrize("axis", [0, 1])
176183
@pytest.mark.parametrize("epilogue_kind", ["reduce1d", "reduce2d", "expand_reduce2d"])
177184
@pytest.mark.parametrize("dtype_str, sanitize_overflow", [("int32", False), ("int32", True), ("float32", False),
178185
("float16", False)])
179186
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
180187
def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, sanitize_overflow, reduce_op, device):
181-
if src_layout == "linear_layout":
182-
src_layout = ttgl.DistributedLinearLayout(
183-
reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], #
184-
lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]], #
185-
warp_bases=[[32, 0], [0, 32]],
186-
block_bases=[],
187-
shape=[M, N],
188-
)
189-
if THREADS_PER_WARP != (1 << len(src_layout.lane_bases)):
190-
pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane_bases)} threads per warp")
191-
elif M < 64 or N < 64:
192-
pytest.skip(f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={M}, N={N}")
193-
if isinstance(src_layout,
194-
(ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)) and (M < src_layout.instr_shape[0]
195-
or N < src_layout.instr_shape[1]):
196-
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
197188

198189
@gluon.jit
199190
def _add(a, b):
@@ -341,9 +332,33 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr):
341332
])
342333

343334

344-
@pytest.mark.parametrize("M, bins", [[2048, 2], [8, 512], [32, 32]])
345-
@pytest.mark.parametrize("src_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]), "linear_layout"])
346-
@pytest.mark.parametrize("dst_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0])])
335+
def _histogram_cases():
    """Yield (M, bins, src_layout, dst_layout) parameter tuples for test_histogram.

    Emits the plain blocked-layout cases for every (M, bins) pair, then
    DistributedLinearLayout source layouts for the pairs with M >= 32
    (the linear layout's register bases need at least 2^5 elements).

    Raises:
        RuntimeError: if THREADS_PER_WARP is neither 32 nor 64.
    """
    import math  # hoisted to the top of the function; was buried mid-body

    if THREADS_PER_WARP not in (32, 64):
        raise RuntimeError(f"Unsupported THREADS_PER_WARP: {THREADS_PER_WARP}")

    m_bins = [(2048, 2), (8, 512), (32, 32)]
    layouts = [(ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]),
                ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))]
    for m, bins in m_bins:
        for src_layout, dst_layout in layouts:
            yield (m, bins, src_layout, dst_layout)

    linear_layouts = [(
        ttgl.DistributedLinearLayout(
            # Registers cover element bits above the 32 handled by lanes+warps.
            reg_bases=[[1 << (5 + i)] for i in range(int(math.log2(m)) - 5)],
            # 64-thread warps get one extra (zero) lane basis.
            lane_bases=[[0], [16], [4], [2], [1]] + ([[0]] if THREADS_PER_WARP == 64 else []),
            warp_bases=[[0], [8]],
            block_bases=[],
            shape=(m, ),
        ),
        bins,
    ) for (m, bins) in m_bins if m >= 32]
    for linear_layout, bins in linear_layouts:
        yield (linear_layout.shape[0], bins, linear_layout, ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]))
359+
360+
361+
@pytest.mark.parametrize("M, bins, src_layout, dst_layout", _histogram_cases())
347362
def test_histogram(M, bins, src_layout, dst_layout, device):
348363

349364
@gluon.jit
@@ -355,18 +370,6 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, B: ttgl.constexpr, src_layout: ttgl.
355370
z_offs = ttgl.arange(0, B, layout=dst_layout)
356371
ttgl.store(z_ptr + z_offs, h)
357372

358-
if src_layout == "linear_layout":
359-
if M == 32:
360-
src_layout = ttgl.DistributedLinearLayout(
361-
reg_bases=[],
362-
lane_bases=[[0], [16], [4], [2], [1]] + [[0]] * (THREADS_PER_WARP >> 6),
363-
warp_bases=[[0], [8]],
364-
block_bases=[],
365-
shape=(M, ),
366-
)
367-
else:
368-
pytest.skip("Linear layout is specialized for 32 elements")
369-
370373
torch.manual_seed(0)
371374
x = torch.randint(0, bins, (M, ), dtype=torch.int32, device=device)
372375
z = torch.zeros((bins, ), dtype=torch.int32, device=device)

0 commit comments

Comments
 (0)