@@ -2545,7 +2545,20 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25452545@pytest .mark .parametrize ("M, N" , [[32 , 16 ], [32 , 32 ], [32 , 64 ], [64 , 32 ]])
25462546@pytest .mark .parametrize ("src_layout" , scan_layouts )
25472547@pytest .mark .parametrize ("axis" , [0 , 1 ])
2548- def test_scan_layouts (M , N , src_layout , axis , device , tmp_path : pathlib .Path ):
2548+ @pytest .mark .parametrize ("add_overflow_check" , [False , True ])
2549+ def test_scan_layouts (M , N , src_layout , axis , add_overflow_check , device , tmp_path : pathlib .Path ):
2550+
2551+ overflow_check = """
2552+ %17 = arith.extsi %arg2 : i32 to i64
2553+ %18 = arith.extsi %arg3 : i32 to i64
2554+ %19 = arith.addi %17, %18 : i64
2555+ %i32.min = arith.constant -2147483648: i64
2556+ %i32.max = arith.constant 2147483647: i64
2557+ %20 = arith.cmpi slt, %19, %i32.max : i64
2558+ %21 = arith.cmpi sge, %19, %i32.min : i64
2559+ %22 = arith.andi %20, %21 : i1
2560+ tt.assert %22, "overflow detected" : i1
2561+ """
25492562
25502563 ir = f"""
25512564 #blocked = { src_layout }
@@ -2565,7 +2578,7 @@ def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
25652578 %10 = tt.load %9 : tensor<{ M } x{ N } x!tt.ptr<i32>, #blocked>
25662579 %11 = "tt.scan"(%10) <{{axis = { axis } : i32, reverse = false}}> ({{
25672580 ^bb0(%arg2: i32, %arg3: i32):
2568- %16 = arith.addi %arg2, %arg3 : i32
2581+ %16 = arith.addi %arg2, %arg3 : i32{ overflow_check if add_overflow_check else "" }
25692582 tt.scan.return %16 : i32
25702583 }}) : (tensor<{ M } x{ N } xi32, #blocked>) -> tensor<{ M } x{ N } xi32, #blocked>
25712584 %12 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<{ M } x1x!tt.ptr<i32>, #blocked>
@@ -2627,9 +2640,11 @@ def test_scan_layouts(M, N, src_layout, axis, device, tmp_path: pathlib.Path):
26272640@pytest .mark .parametrize ("src_layout" , filter_layouts (layouts ))
26282641@pytest .mark .parametrize ("axis" , [0 , 1 ])
26292642@pytest .mark .parametrize ("epilogue_kind" , ['reduce1d' , 'reduce2d' , 'expand_reduce2d' ])
2630- @pytest .mark .parametrize ("dtype_str" , ["int32" , "float32" , "float16" ])
2643+ @pytest .mark .parametrize ("dtype_str,add_overflow_check" , [("int32" , False ), ("int32" , True ), ("float32" , False ),
2644+ ("float16" , False )])
26312645@pytest .mark .parametrize ("reduce_op" , ["sum" , "max" ])
2632- def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , reduce_op , device , tmp_path : pathlib .Path ):
2646+ def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , add_overflow_check , reduce_op , device ,
2647+ tmp_path : pathlib .Path ):
26332648 if isinstance (src_layout ,
26342649 (MfmaLayout , MmaLayout )) and (M < src_layout .instr_shape [0 ] or N < src_layout .instr_shape [1 ]):
26352650 pytest .skip ("Skipping because tensor shape is smaller than M(f)maLayout instr_shape" )
@@ -2641,6 +2656,18 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26412656 if isinstance (src_layout , MmaLayout ) and src_layout .version == 3 :
26422657 src_layout [2 ] = 16 if dtype_str == "float16" else 8
26432658
2659+ overflow_check = """
2660+ %18 = arith.extsi %arg3 : i32 to i64
2661+ %19 = arith.extsi %arg4 : i32 to i64
2662+ %20 = arith.addi %18, %19 : i64
2663+ %i32.min = arith.constant -2147483648: i64
2664+ %i32.max = arith.constant 2147483647: i64
2665+ %21 = arith.cmpi slt, %20, %i32.max : i64
2666+ %22 = arith.cmpi sge, %20, %i32.min : i64
2667+ %23 = arith.andi %21, %22 : i1
2668+ tt.assert %23, "overflow detected" : i1
2669+ """
2670+
26442671 ty = {"int32" : "i32" , "float32" : "f32" , "float16" : "f16" }[dtype_str ]
26452672 arith_op = {
26462673 "max" : {"int32" : "arith.maxsi" , "float32" : "arith.maximumf" , "float16" : "arith.maximumf" }, #
@@ -2673,7 +2700,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26732700 f"""
26742701 %14 = "tt.reduce"(%13) ({{
26752702 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2676- %17 = { arith_op } %arg3, %arg4 : { ty }
2703+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
26772704 tt.reduce.return %17 : { ty }
26782705 }}) {{axis = 0 : i32}} : (tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>>) -> { ty }
26792706 tt.store %arg2, %14 : !tt.ptr<{ ty } >
@@ -2685,7 +2712,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26852712 %14 = tt.expand_dims %13 {{axis = { axis } : i32}} : tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>> -> tensor<{ expanded_shape } x{ ty } , #src>
26862713 %15 = "tt.reduce"(%14) ({{
26872714 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2688- %17 = { arith_op } %arg3, %arg4 : { ty }
2715+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
26892716 tt.reduce.return %17 : { ty }
26902717 }}) {{axis = { other_axis } : i32}} : (tensor<{ expanded_shape } x{ ty } , #src>) -> (tensor<1x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { other_axis } , parent = #src}}>>)
26912718 %16 = triton_gpu.convert_layout %15 : tensor<1x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { other_axis } , parent = #src}}>> -> tensor<1x{ ty } , #one_d_layout>
@@ -2718,7 +2745,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
27182745 %12 = { GPU_DIALECT } .convert_layout %11 : tensor<{ M } x{ N } x{ ty } , #blocked> -> tensor<{ M } x{ N } x{ ty } , #src>
27192746 %13 = "tt.reduce"(%12) ({{
27202747 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2721- %17 = { arith_op } %arg3, %arg4 : { ty }
2748+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
27222749 tt.reduce.return %17 : { ty }
27232750 }}) {{axis = { axis } : i32}} : (tensor<{ M } x{ N } x{ ty } , #src>) -> tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>>
27242751 """ + epilogue
0 commit comments