Commit 8bd4dd1

[GLUON] Support associative_scan and device_assert and remove ttgir scan layout tests (#7894)
1 parent 0e71b2c commit 8bd4dd1

5 files changed, +89 -100 lines changed

python/test/gluon/test_core.py

Lines changed: 15 additions & 18 deletions
@@ -5,7 +5,7 @@
 import triton
 import triton.language as tl
 
-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
@@ -14,6 +14,8 @@
 from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
 from triton.experimental.gluon.language.extra import libdevice
 
+THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
+
 
 @gluon.jit
 def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
@@ -24,18 +26,15 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
     ttgl.store(Out + xoffset, data, xmask)
 
 
-copy_kernel_tpw = [32] if is_cuda() else [64]
-
-
 @pytest.mark.parametrize("layout", [
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[4], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
-    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=copy_kernel_tpw, warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[4], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[2], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[4], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
+    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
 ])
 @pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
 def test_copy_kernel(layout, XBLOCK):
@@ -403,13 +402,12 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
         y = libdevice.fast_expf(x)
         ttgl.store(y_ptr + offs, y)
 
-    warp_size = 32 if is_cuda() else 64
     num_warps = 4
 
     torch.manual_seed(0)
-    x = torch.randn(warp_size * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
     y = torch.empty_like(x)
-    fast_expf_kernel[(1, )](x, y, warp_size, num_warps)
+    fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
     torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)
 
 
@@ -425,13 +423,12 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
         z = libdevice.fast_dividef(x, y)
         ttgl.store(z_ptr + offs, z)
 
-    warp_size = 32 if is_cuda() else 64
     num_warps = 4
 
     torch.manual_seed(0)
-    x = torch.randn(warp_size * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
     y = torch.randn_like(x)
     z = torch.empty_like(x)
     y[y == 0] = 1.0
-    fast_dividef_kernel[(1, )](x, y, z, warp_size, num_warps)
+    fast_dividef_kernel[(1, )](x, y, z, THREADS_PER_WARP, num_warps)
     torch.testing.assert_close(z, torch.div(x, y), atol=1e-5, rtol=1e-4)
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import torch
+import pytest
+
+import triton
+from triton.experimental import gluon
+from triton.experimental.gluon import language as ttgl
+
+THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size
+
+
+@pytest.mark.parametrize("M, N", [(32, 16), (32, 32), (32, 64), (64, 32)])
+@pytest.mark.parametrize("src_layout", [
+    ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
+    ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
+    ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
+    ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
+    ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
+    ttgl.BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
+    ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
+    ttgl.BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
+    ttgl.BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
+    ttgl.BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
+    ttgl.BlockedLayout([1, 2], [1, THREADS_PER_WARP], [1, 4], [1, 0]),
+])
+@pytest.mark.parametrize("axis", [0, 1])
+@pytest.mark.parametrize("sanitize_overflow", [False, True])
+def test_scan_layouts(M, N, src_layout, axis, sanitize_overflow, device):
+
+    @gluon.jit
+    def _combine(a, b):
+        return a + b
+
+    @gluon.jit
+    def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.constexpr, axis: ttgl.constexpr):
+        x_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, layout))[:, None]
+        x_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, layout))[None, :]
+        x = ttgl.load(x_ptr + x_offs_m * N + x_offs_n)
+        y = ttgl.associative_scan(x, axis=axis, combine_fn=_combine)
+        ttgl.store(z_ptr + x_offs_m * N + x_offs_n, y)
+
+    torch.manual_seed(0)
+
+    x = torch.randint(-100, 100, (M, N), dtype=torch.int32, device=device)
+    z = torch.zeros((M, N), dtype=torch.int32, device=device)
+    z_tri = torch.empty_like(z)
+
+    kernel[(1, 1, 1)](x, z_tri, M, N, src_layout, axis, num_warps=4, sanitize_overflow=sanitize_overflow,
+                      debug=sanitize_overflow)
+
+    z_ref = torch.cumsum(x, dim=axis, dtype=torch.int32)
+    torch.testing.assert_close(z_tri, z_ref)

python/test/unit/language/test_core.py

Lines changed: 0 additions & 82 deletions
@@ -3034,21 +3034,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3)
 
 
-scan_layouts = [
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-]
-
-
 def test_no_rematerialization_op():
 
     if torch.version.hip:
@@ -3094,73 +3079,6 @@ def kernel(
     assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce"
 
 
-@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
-@pytest.mark.parametrize("src_layout", scan_layouts)
-@pytest.mark.parametrize("axis", [0, 1])
-@pytest.mark.parametrize("add_overflow_check", [False, True])
-def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
-
-    overflow_check = """
-    %17 = arith.extsi %arg2 : i32 to i64
-    %18 = arith.extsi %arg3 : i32 to i64
-    %19 = arith.addi %17, %18 : i64
-    %i32.min = arith.constant -2147483648: i64
-    %i32.max = arith.constant 2147483647: i64
-    %20 = arith.cmpi slt, %19, %i32.max : i64
-    %21 = arith.cmpi sge, %19, %i32.min : i64
-    %22 = arith.andi %20, %21 : i1
-    tt.assert %22, "overflow detected" : i1
-    """
-
-    ir = f"""
-    #blocked = {src_layout}
-    module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-    tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
-      %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
-      %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
-      %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
-      %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
-      %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-      %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
-      %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
-      %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
-      %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr<i32>, #blocked> -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
-      %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
-      %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{
-      ^bb0(%arg2: i32, %arg3: i32):
-        %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""}
-        tt.scan.return %16 : i32
-      }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
-      %12 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-      %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
-      %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr<i32>, #blocked> -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
-      tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      tt.return
-    }}
-    }}
-    """
-
-    temp_file = tmp_path / "test_scan_layouts.ttgir"
-    temp_file.write_text(ir)
-    kernel = triton.compile(str(temp_file))
-
-    rs = RandomState(17)
-    x = rs.randint(-100, 100, (M, N)).astype('int32')
-
-    z = np.zeros((M, N)).astype('int32')
-    x_tri = torch.tensor(x, device=device)
-    z_tri = torch.tensor(z, device=device)
-
-    kernel[(1, 1, 1)](x_tri, z_tri)
-
-    z_ref = np.cumsum(x, axis=axis)
-
-    np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
-
-
 layouts = [
     BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,7 @@
 )
 
 _IMPORT_FROM_TRITON: List[str] = [
+    "associative_scan",
     "atomic_add",
     "atomic_and",
     "atomic_cas",
@@ -54,6 +55,7 @@
     "atomic_xchg",
     "atomic_xor",
     "broadcast",
+    "device_assert",
     "expand_dims",
     "inline_asm_elementwise",
     "join",

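The `device_assert` re-export added above is not exercised by the new tests, so here is a minimal sketch (not part of this commit) of how it could be called from a Gluon kernel. The kernel name, layout argument, and launch configuration are illustrative assumptions, and the assertion only traps when device-side assertions are enabled, for example by compiling with debug=True (the same switch the new scan test uses for sanitize_overflow).

import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def check_positive_kernel(x_ptr, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
    # Load one block of values and trap if any element is not strictly positive
    # (active only when device-side assertions are enabled).
    offs = ttgl.arange(0, XBLOCK, layout=layout)
    x = ttgl.load(x_ptr + offs)
    ttgl.device_assert(x > 0, "expected strictly positive inputs")

A launch such as check_positive_kernel[(1, )](x, 128, some_blocked_layout, debug=True), with some_blocked_layout a hypothetical 1-D ttgl.BlockedLayout covering 128 elements, would then report the failed assertion at runtime.
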
python/triton/experimental/gluon/language/_semantic.py

Lines changed: 21 additions & 0 deletions
@@ -280,6 +280,27 @@ def _check_same_layout(xs):
         _check(all(l == l0 for l in layouts[1:]),
                lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
 
+    def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
+                         reverse: bool) -> Tuple[TensorTy, ...]:
+        shape = inputs[0].type.shape
+        rank = len(shape)
+
+        assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
+
+        if axis < 0:
+            axis += rank
+
+        for t in inputs:
+            assert t.type.shape == shape, "all scan inputs must have the same shape"
+
+        scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
+        region_builder_fn(scan_op)
+        assert scan_op.verify()
+
+        return tuple(
+            self._wrap_handle_infer_layout(scan_op.get_result(i), inputs[i].type.scalar, shape)
+            for i in range(len(inputs)))
+
     def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
         _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
         # get result shape
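
For context, this semantic hook is what backs the `ttgl.associative_scan` used in the new Gluon test above. As a rough sketch of a scan beyond a plain cumulative sum, the kernel below computes a suffix sum with reverse=True; the kernel, wrapper, and layout are illustrative assumptions (not part of this commit), and it presumes the reverse keyword of tl.associative_scan carries over to the Gluon re-export unchanged.

import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def _add(a, b):
    return a + b


@gluon.jit
def suffix_sum_kernel(x_ptr, out_ptr, N: ttgl.constexpr, layout: ttgl.constexpr):
    offs = ttgl.arange(0, N, layout=layout)
    x = ttgl.load(x_ptr + offs)
    # reverse=True scans from the last element toward the first (a suffix sum).
    y = ttgl.associative_scan(x, axis=0, combine_fn=_add, reverse=True)
    ttgl.store(out_ptr + offs, y)


def suffix_sum(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is a 1-D GPU tensor whose power-of-two length equals
    # warp_size * 4, matching the blocked layout built below.
    warp_size = triton.runtime.driver.active.get_current_target().warp_size
    layout = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[warp_size],
                                warps_per_cta=[4], order=[0])
    out = torch.empty_like(x)
    suffix_sum_kernel[(1, )](x, out, x.numel(), layout, num_warps=4)
    return out

The result can be checked against torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), mirroring the torch.cumsum reference used by the new test.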
