Skip to content

Commit d3ea299

Browse files
authored
[TEST][EASY] Remove test_chain_reduce from triton's test_core.py (#8038)
As it is a duplicate of the `reduce2d` mode of `test_reduce_layouts` in gluon's `test_lowerings.py`
1 parent ce47711 commit d3ea299

File tree

1 file changed

+0
-76
lines changed

1 file changed

+0
-76
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -3079,82 +3079,6 @@ def kernel(
30793079
assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce"
30803080

30813081

3082-
layouts = [
3083-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
3084-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
3085-
# [HIP] TO DO: some tests are flaky with the layout, so turn off them for now.
3086-
# BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
3087-
BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
3088-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1])
3089-
]
3090-
3091-
3092-
@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
3093-
@pytest.mark.parametrize("src_layout", layouts)
3094-
@pytest.mark.parametrize("op", ["sum", "max"])
3095-
@pytest.mark.parametrize("first_axis", [0, 1])
3096-
def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path):
3097-
3098-
op_str = ""
3099-
if op == "sum":
3100-
op_str = """
3101-
%13 = arith.addi %arg2, %arg3 : i32
3102-
tt.reduce.return %13 : i32"""
3103-
elif op == "max":
3104-
op_str = """
3105-
%13 = arith.cmpi "sgt", %arg2, %arg3 : i32
3106-
%14 = arith.select %13, %arg2, %arg3 : i32
3107-
tt.reduce.return %14 : i32"""
3108-
ir = f"""
3109-
#src = {src_layout}
3110-
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
3111-
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
3112-
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
3113-
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
3114-
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
3115-
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
3116-
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>>
3117-
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
3118-
%5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src>
3119-
%6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src>
3120-
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
3121-
%8 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
3122-
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
3123-
%10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr<i32>, #src>
3124-
%11 = "tt.reduce"(%10) ({{
3125-
^bb0(%arg2: i32, %arg3: i32):
3126-
{op_str}
3127-
}}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>
3128-
%12 = "tt.reduce"(%11) ({{
3129-
^bb0(%arg2: i32, %arg3: i32):
3130-
{op_str}
3131-
}}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32
3132-
tt.store %arg1, %12 : !tt.ptr<i32>
3133-
tt.return
3134-
}}
3135-
}}
3136-
"""
3137-
temp_file = tmp_path / "test_chain_reduce.ttgir"
3138-
temp_file.write_text(ir)
3139-
kernel = triton.compile(str(temp_file))
3140-
3141-
rs = RandomState(17)
3142-
x = rs.randint(0, 4, (M, N)).astype('int32')
3143-
3144-
z = np.zeros((1, )).astype('int32')
3145-
3146-
x_tri = torch.tensor(x, device=device)
3147-
z_tri = torch.tensor(z, device=device)
3148-
3149-
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
3150-
if op == "sum":
3151-
z_ref = np.sum(x)
3152-
elif op == "max":
3153-
z_ref = np.max(x)
3154-
3155-
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
3156-
3157-
31583082
@triton.jit
31593083
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
31603084
delta = mean_2 - mean_1

0 commit comments

Comments (0)