Skip to content

Commit 11d85f7

Browse files
authored
[BACKEND] Fix asserts in 2d scan and add assert mode to layout tests (#5075)
This is a follow up to #5033 but for scan ops, and also improving the testing as it was clearly insufficient before.
1 parent 94684d3 commit 11d85f7

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
187187
}
188188
Value mask = icmp_sge(warpId, i32_val(i + 1));
189189
accumulator.acc =
190-
accumulate(helper, rewriter, accumulator.acc, partialReduce, mask);
190+
accumulate(helper, rewriter, accumulator.acc, partialReduce);
191191
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
192192
accumulator.maskedAcc[j] =
193193
select(mask, accumulator.acc[j], accumulator.maskedAcc[j]);

python/test/unit/language/test_core.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,7 +2545,20 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25452545
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
25462546
@pytest.mark.parametrize("src_layout", scan_layouts)
25472547
@pytest.mark.parametrize("axis", [0, 1])
2548-
def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
2548+
@pytest.mark.parametrize("add_overflow_check", [False, True])
2549+
def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
2550+
2551+
overflow_check = """
2552+
%17 = arith.extsi %arg2 : i32 to i64
2553+
%18 = arith.extsi %arg3 : i32 to i64
2554+
%19 = arith.addi %17, %18 : i64
2555+
%i32.min = arith.constant -2147483648: i64
2556+
%i32.max = arith.constant 2147483647: i64
2557+
%20 = arith.cmpi slt, %19, %i32.max : i64
2558+
%21 = arith.cmpi sge, %19, %i32.min : i64
2559+
%22 = arith.andi %20, %21 : i1
2560+
tt.assert %22, "overflow detected" : i1
2561+
"""
25492562

25502563
ir = f"""
25512564
#blocked = {src_layout}
@@ -2565,7 +2578,7 @@ def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
25652578
%10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
25662579
%11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{
25672580
^bb0(%arg2: i32, %arg3: i32):
2568-
%16 = arith.addi %arg2, %arg3 : i32
2581+
%16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""}
25692582
tt.scan.return %16 : i32
25702583
}}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
25712584
%12 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
@@ -2627,9 +2640,11 @@ def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
26272640
@pytest.mark.parametrize("src_layout", filter_layouts(layouts))
26282641
@pytest.mark.parametrize("axis", [0, 1])
26292642
@pytest.mark.parametrize("epilogue_kind", ['reduce1d', 'reduce2d', 'expand_reduce2d'])
2630-
@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
2643+
@pytest.mark.parametrize("dtype_str,add_overflow_check", [("int32", False), ("int32", True), ("float32", False),
2644+
("float16", False)])
26312645
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
2632-
def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
2646+
def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_overflow_check, reduce_op, device,
2647+
tmp_path: pathlib.Path):
26332648
if isinstance(src_layout,
26342649
(MfmaLayout, MmaLayout)) and (M < src_layout.instr_shape[0] or N < src_layout.instr_shape[1]):
26352650
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
@@ -2641,6 +2656,18 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26412656
if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
26422657
src_layout[2] = 16 if dtype_str == "float16" else 8
26432658

2659+
overflow_check = """
2660+
%18 = arith.extsi %arg3 : i32 to i64
2661+
%19 = arith.extsi %arg4 : i32 to i64
2662+
%20 = arith.addi %18, %19 : i64
2663+
%i32.min = arith.constant -2147483648: i64
2664+
%i32.max = arith.constant 2147483647: i64
2665+
%21 = arith.cmpi slt, %20, %i32.max : i64
2666+
%22 = arith.cmpi sge, %20, %i32.min : i64
2667+
%23 = arith.andi %21, %22 : i1
2668+
tt.assert %23, "overflow detected" : i1
2669+
"""
2670+
26442671
ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
26452672
arith_op = {
26462673
"max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, #
@@ -2673,7 +2700,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26732700
f"""
26742701
%14 = "tt.reduce"(%13) ({{
26752702
^bb0(%arg3: {ty}, %arg4: {ty}):
2676-
%17 = {arith_op} %arg3, %arg4 : {ty}
2703+
%17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""}
26772704
tt.reduce.return %17 : {ty}
26782705
}}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty}
26792706
tt.store %arg2, %14 : !tt.ptr<{ty}>
@@ -2685,7 +2712,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26852712
%14 = tt.expand_dims %13 {{axis = {axis} : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> -> tensor<{expanded_shape}x{ty}, #src>
26862713
%15 = "tt.reduce"(%14) ({{
26872714
^bb0(%arg3: {ty}, %arg4: {ty}):
2688-
%17 = {arith_op} %arg3, %arg4 : {ty}
2715+
%17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""}
26892716
tt.reduce.return %17 : {ty}
26902717
}}) {{axis = {other_axis} : i32}} : (tensor<{expanded_shape}x{ty}, #src>) -> (tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>>)
26912718
%16 = triton_gpu.convert_layout %15 : tensor<1x{ty}, #{GPU_DIALECT}.slice<{{dim = {other_axis}, parent = #src}}>> -> tensor<1x{ty}, #one_d_layout>
@@ -2718,7 +2745,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27182745
%12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
27192746
%13 = "tt.reduce"(%12) ({{
27202747
^bb0(%arg3: {ty}, %arg4: {ty}):
2721-
%17 = {arith_op} %arg3, %arg4 : {ty}
2748+
%17 = {arith_op} %arg3, %arg4 : {ty}{overflow_check if add_overflow_check else ""}
27222749
tt.reduce.return %17 : {ty}
27232750
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
27242751
""" + epilogue

0 commit comments

Comments
 (0)