
Commit e3b6aaa

Merge commit 'fff5a2dd02081b4b4e6fbeae8b55ff46a6d89462'
2 parents: d847b10 + fff5a2d


25 files changed: +618 / -397 lines


Makefile

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ test-unit: all
 	--ignore=language/test_subprocess.py --ignore=test_debug.py
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/language/test_subprocess.py
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/test_debug.py --forked
-	$(PYTEST) -s -n 8 python/triton_kernels/tests/
+	$(PYTEST) -s -n 6 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 	# Run attention separately to avoid out of gpu memory
 	$(PYTEST) -vs python/tutorials/06-fused-attention.py

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-bc773632355b3cebde350b0341624e88be40b744
+064f02dac0c81c19350a74415b3245f42fed09dc

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 0 deletions
@@ -645,6 +645,8 @@ class ScaledBlockedToMMAv5
     auto CTALayout = getCTALayout(oldRetType.getEncoding());
     if ((computeCapability) / 10 != 10)
       return failure();
+    if (numWarps != 4 && numWarps != 8)
+      return failure();
     if (retShapePerCTA[0] < 128 || retShapePerCTA[1] < 8)
       return failure();
     Location loc = dotOp.getLoc();
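For orientation, the added guard makes ScaledBlockedToMMAv5 bail out unless the kernel is compiled with 4 or 8 warps, on top of the existing compute-capability and per-CTA shape checks. A minimal Python sketch of the combined gating logic, using a hypothetical helper name (mmav5_scaled_dot_eligible) purely for illustration, not the compiler's actual API:

def mmav5_scaled_dot_eligible(compute_capability: int, num_warps: int,
                              ret_shape_per_cta: tuple[int, int]) -> bool:
    # Mirrors the early-exit checks in ScaledBlockedToMMAv5: Blackwell only
    # (compute capability 10.x), 4 or 8 warps, and a per-CTA result tile of
    # at least 128 x 8; anything else makes the pattern return failure().
    if compute_capability // 10 != 10:
        return False
    if num_warps not in (4, 8):
        return False
    m, n = ret_shape_per_cta
    return m >= 128 and n >= 8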

python/test/gluon/test_lowerings.py

Lines changed: 152 additions & 13 deletions
@@ -4,24 +4,49 @@
 import triton
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
+from triton._internal_testing import is_cuda, is_hip, is_hopper_or_newer
+
+
+def _is_layout_applicable(layout) -> bool:
+    if isinstance(layout, ttgl.SliceLayout):
+        return _is_layout_applicable(layout.parent)
+    elif is_cuda():
+        mma_layout = layout.parent if isinstance(layout, ttgl.DotOperandLayout) else layout
+        if not isinstance(mma_layout, ttgl.NVMMADistributedLayout):
+            return False
+        if mma_layout.version[0] >= 3 and not is_hopper_or_newer():
+            return False
+        return True
+    elif is_hip():
+        # TODO: Add other amd layouts
+        return isinstance(layout, ttgl.amd.AMDMFMALayout)
+    else:
+        return True
+
+
+def _filter_layouts(layouts):
+    return [l for l in layouts if _is_layout_applicable(l)]
+
 
 THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
 
 
 @pytest.mark.parametrize("M, N", [(32, 16), (32, 32), (32, 64), (64, 32)])
-@pytest.mark.parametrize("src_layout", [
-    ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
-    ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
-    ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
-    ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
-    ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
-    ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
-    ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
-    ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
-    ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
-    ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
-    ttgl.BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0]),
-])
+@pytest.mark.parametrize(
+    "src_layout",
+    _filter_layouts([
+        ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
+        ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
+        ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
+        ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
+        ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
+        ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
+        ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
+        ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
+        ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
+        ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
+        ttgl.BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0]),
+    ]))
 @pytest.mark.parametrize("axis", [0, 1])
 @pytest.mark.parametrize("sanitize_overflow", [False, True])
 def test_scan_layouts(M, N, src_layout, axis, sanitize_overflow, device):
@@ -49,3 +74,117 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons
 
     z_ref = torch.cumsum(x, dim=axis, dtype=torch.int32)
     torch.testing.assert_close(z_tri, z_ref)
+
+
+@pytest.mark.parametrize("M, N", [[128, 16], [32, 128], [32, 32], [16, 16]])
+@pytest.mark.parametrize(
+    "src_layout",
+    _filter_layouts([
+        # FIXME: Do not enable these tests until the SLPVectorizor problem with nvptx target has been resolved
+        # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
+        # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])),
+        ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+        ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
+        ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1],
+                                    cta_order=[0, 1], instr_shape=[16, 8]),
+        ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1],
+                                    cta_order=[1, 0], instr_shape=[16, 16, 16]),
+        ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
+                               transposed=False),
+        ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
+                               transposed=False),
+        ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
+                               transposed=True),
+        ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
+                               transposed=True),
+        # TODO: AMDWMMA layouts
+        # WmmaLayout(version=1, warps_per_cta=[4, 1]),
+        # WmmaLayout(version=1, warps_per_cta=[1, 4]),
+        ttgl.DotOperandLayout(
+            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
+            operand_index=1, k_width=8),
+        ttgl.DotOperandLayout(
+            parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]),
+            operand_index=0, k_width=2),
+        ttgl.SliceLayout(
+            dim=0,
+            parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                               cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16, 8])),
+        ttgl.SliceLayout(
+            dim=1, parent=ttgl.DotOperandLayout(
+                parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1],
+                                                   cta_split_num=[1, 1, 1], cta_order=[2, 1, 0],
+                                                   instr_shape=[1, 16, 8]),
+                operand_index=1, k_width=2)),
+        "linear_layout",
+    ]))
+@pytest.mark.parametrize("axis", [0, 1])
+@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
+@pytest.mark.parametrize("dtype_str, sanitize_overflow", [("int32", False), ("int32", True), ("float32", False),
+                                                          ("float16", False)])
+@pytest.mark.parametrize("reduce_op", ["sum", "max"])
+def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, sanitize_overflow, reduce_op, device):
+    if src_layout == "linear_layout":
+        src_layout = ttgl.DistributedLinearLayout(reg_bases=[[0, 16], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0]],
+                                                  lane_bases=[[0, 0], [0, 1], [0, 2], [0, 4], [0, 8]],
+                                                  warp_bases=[[32, 0], [0, 32]], block_bases=[], shape=[M, N])
+        if THREADS_PER_WARP != (1 << len(src_layout.lane_bases)):
+            pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane_bases)} threads per warp")
+        elif M < 64 or N < 64:
+            pytest.skip(f"Skipping. This LinearLayout assumes M >= 64 and N >= 64, got M={M}, N={N}")
+    if isinstance(src_layout,
+                  (ttgl.amd.AMDMFMALayout, ttgl.NVMMADistributedLayout)) and (M < src_layout.instr_shape[0]
+                                                                              or N < src_layout.instr_shape[1]):
+        pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
+
+    @gluon.jit
+    def _add(a, b):
+        return a + b
+
+    @gluon.jit
+    def _max(a, b):
+        return ttgl.maximum(a, b)
+
+    combine_fn = _add if reduce_op == "sum" else _max
+
+    @gluon.jit
+    def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.constexpr, axis: ttgl.constexpr,
+               epilogue_kind: ttgl.constexpr):
+        x_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, layout))[:, None]
+        x_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, layout))[None, :]
+        x = ttgl.load(x_ptr + x_offs_m * N + x_offs_n)
+        y = ttgl.reduce(x, axis=axis, combine_fn=combine_fn)
+        if epilogue_kind == "reduce1d":
+            if axis == 0:
+                z_offs = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, layout))
+            else:
+                z_offs = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, layout))
+            ttgl.store(z_ptr + z_offs, y)
+        elif epilogue_kind == "reduce2d":
+            y = ttgl.reduce(y, axis=0, combine_fn=combine_fn)
+            ttgl.store(z_ptr, y)
+        elif epilogue_kind == "expand_reduce2d":
+            y = ttgl.expand_dims(y, axis=axis)
+            y = ttgl.reduce(y, axis=1 - axis, combine_fn=combine_fn)
+            z_offs = ttgl.arange(0, 1, layout=ttgl.SliceLayout(1 - axis, layout))
+            ttgl.store(z_ptr + z_offs, y)
+
+    torch.manual_seed(0)
+
+    torch_dtype = getattr(torch, dtype_str)
+    x = torch.randint(-10, 10, (M, N), dtype=torch.int32, device=device).to(torch_dtype)
+    out_shape = (1, 1) if "reduce2d" in epilogue_kind else (1, N) if axis == 0 else (M, 1)
+    z = torch.empty(out_shape, dtype=torch_dtype, device=device)
+
+    num_warps = int(torch.prod(torch.tensor(ttgl._layouts.warps_per_cta(src_layout, (M, N)))))
+    kernel[(1, 1, 1)](x, z, M, N, src_layout, axis, num_warps=num_warps, epilogue_kind=epilogue_kind,
+                      sanitize_overflow=sanitize_overflow, debug=sanitize_overflow)
+
+    reduce_fn = torch.sum if reduce_op == "sum" else torch.amax
+    z_ref = reduce_fn(x, dim=axis, keepdim=True)
+    if epilogue_kind in ("expand_reduce2d", "reduce2d"):
+        z_ref = reduce_fn(z_ref, dim=1 - axis, keepdim=True)
+    torch.testing.assert_close(z, z_ref.to(torch_dtype))
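As a usage note, the new reduce-layout coverage can be exercised on its own; a sketch of a local invocation, assuming a working Triton build with a supported GPU and the repository root as the working directory (the file path and -k filter are taken from this diff):

# Sketch: run only the new test_reduce_layouts cases added in this commit.
# Assumes Triton is installed/built and a supported GPU is available.
import pytest

exit_code = pytest.main([
    "-s",
    "python/test/gluon/test_lowerings.py",
    "-k", "test_reduce_layouts",
])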
