@@ -2517,7 +2517,20 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
25172517@pytest .mark .parametrize ("M, N" , [[32 , 16 ], [32 , 32 ], [32 , 64 ], [64 , 32 ]])
25182518@pytest .mark .parametrize ("src_layout" , scan_layouts )
25192519@pytest .mark .parametrize ("axis" , [0 , 1 ])
2520- def test_scan_layouts (M , N , src_layout , axis , device ):
2520+ @pytest .mark .parametrize ("add_overflow_check" , [False , True ])
2521+ def test_scan_layouts (M , N , src_layout , axis , add_overflow_check , device ):
2522+
2523+ overflow_check = """
2524+ %17 = arith.extsi %arg2 : i32 to i64
2525+ %18 = arith.extsi %arg3 : i32 to i64
2526+ %19 = arith.addi %17, %18 : i64
2527+ %i32.min = arith.constant -2147483648: i64
2528+ %i32.max = arith.constant 2147483647: i64
2529+ %20 = arith.cmpi slt, %19, %i32.max : i64
2530+ %21 = arith.cmpi sge, %19, %i32.min : i64
2531+ %22 = arith.andi %20, %21 : i1
2532+ tt.assert %22, "overflow detected" : i1
2533+ """
25212534
25222535 ir = f"""
25232536 #blocked = { src_layout }
@@ -2537,7 +2550,7 @@ def test_scan_layouts(M, N, src_layout, axis, device):
25372550 %10 = tt.load %9 : tensor<{ M } x{ N } x!tt.ptr<i32>, #blocked>
25382551 %11 = "tt.scan"(%10) <{{axis = { axis } : i32, reverse = false}}> ({{
25392552 ^bb0(%arg2: i32, %arg3: i32):
2540- %16 = arith.addi %arg2, %arg3 : i32
2553+ %16 = arith.addi %arg2, %arg3 : i32{ overflow_check if add_overflow_check else "" }
25412554 tt.scan.return %16 : i32
25422555 }}) : (tensor<{ M } x{ N } xi32, #blocked>) -> tensor<{ M } x{ N } xi32, #blocked>
25432556 %12 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<{ M } x1x!tt.ptr<i32>, #blocked>
@@ -2599,9 +2612,10 @@ def test_scan_layouts(M, N, src_layout, axis, device):
25992612@pytest .mark .parametrize ("src_layout" , filter_layouts (layouts ))
26002613@pytest .mark .parametrize ("axis" , [0 , 1 ])
26012614@pytest .mark .parametrize ("epilogue_kind" , ['reduce1d' , 'reduce2d' , 'expand_reduce2d' ])
2602- @pytest .mark .parametrize ("dtype_str" , ["int32" , "float32" , "float16" ])
2615+ @pytest .mark .parametrize ("dtype_str,add_overflow_check" , [("int32" , False ), ("int32" , True ), ("float32" , False ),
2616+ ("float16" , False )])
26032617@pytest .mark .parametrize ("reduce_op" , ["sum" , "max" ])
2604- def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , reduce_op , device ):
2618+ def test_reduce_layouts (M , N , src_layout , axis , epilogue_kind , dtype_str , add_overflow_check , reduce_op , device ):
26052619 if isinstance (src_layout ,
26062620 (MfmaLayout , MmaLayout )) and (M < src_layout .instr_shape [0 ] or N < src_layout .instr_shape [1 ]):
26072621 pytest .skip ("Skipping because tensor shape is smaller than M(f)maLayout instr_shape" )
@@ -2613,6 +2627,18 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26132627 if isinstance (src_layout , MmaLayout ) and src_layout .version == 3 :
26142628 src_layout [2 ] = 16 if dtype_str == "float16" else 8
26152629
2630+ overflow_check = """
2631+ %18 = arith.extsi %arg3 : i32 to i64
2632+ %19 = arith.extsi %arg4 : i32 to i64
2633+ %20 = arith.addi %18, %19 : i64
2634+ %i32.min = arith.constant -2147483648: i64
2635+ %i32.max = arith.constant 2147483647: i64
2636+ %21 = arith.cmpi slt, %20, %i32.max : i64
2637+ %22 = arith.cmpi sge, %20, %i32.min : i64
2638+ %23 = arith.andi %21, %22 : i1
2639+ tt.assert %23, "overflow detected" : i1
2640+ """
2641+
26162642 ty = {"int32" : "i32" , "float32" : "f32" , "float16" : "f16" }[dtype_str ]
26172643 arith_op = {
26182644 "max" : {"int32" : "arith.maxsi" , "float32" : "arith.maximumf" , "float16" : "arith.maximumf" }, #
@@ -2645,7 +2671,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26452671 f"""
26462672 %14 = "tt.reduce"(%13) ({{
26472673 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2648- %17 = { arith_op } %arg3, %arg4 : { ty }
2674+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
26492675 tt.reduce.return %17 : { ty }
26502676 }}) {{axis = 0 : i32}} : (tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>>) -> { ty }
26512677 tt.store %arg2, %14 : !tt.ptr<{ ty } >
@@ -2657,7 +2683,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26572683 %14 = tt.expand_dims %13 {{axis = { axis } : i32}} : tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>> -> tensor<{ expanded_shape } x{ ty } , #src>
26582684 %15 = "tt.reduce"(%14) ({{
26592685 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2660- %17 = { arith_op } %arg3, %arg4 : { ty }
2686+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
26612687 tt.reduce.return %17 : { ty }
26622688 }}) {{axis = { other_axis } : i32}} : (tensor<{ expanded_shape } x{ ty } , #src>) -> (tensor<1x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { other_axis } , parent = #src}}>>)
26632689 %16 = triton_gpu.convert_layout %15 : tensor<1x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { other_axis } , parent = #src}}>> -> tensor<1x{ ty } , #one_d_layout>
@@ -2690,7 +2716,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, reduce
26902716 %12 = { GPU_DIALECT } .convert_layout %11 : tensor<{ M } x{ N } x{ ty } , #blocked> -> tensor<{ M } x{ N } x{ ty } , #src>
26912717 %13 = "tt.reduce"(%12) ({{
26922718 ^bb0(%arg3: { ty } , %arg4: { ty } ):
2693- %17 = { arith_op } %arg3, %arg4 : { ty }
2719+ %17 = { arith_op } %arg3, %arg4 : { ty } { overflow_check if add_overflow_check else "" }
26942720 tt.reduce.return %17 : { ty }
26952721 }}) {{axis = { axis } : i32}} : (tensor<{ M } x{ N } x{ ty } , #src>) -> tensor<{ rdims_1d } x{ ty } , #{ GPU_DIALECT } .slice<{{dim = { axis } , parent = #src}}>>
26962722 """ + epilogue