@@ -3079,82 +3079,6 @@ def kernel(
     assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce"


-layouts = [
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    # [HIP] TO DO: some tests are flaky with the layout, so turn off them for now.
-    # BlockedLayout([1, 4], [1, THREADS_PER_WARP], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1])
-]
-
-
-@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
-@pytest.mark.parametrize("src_layout", layouts)
-@pytest.mark.parametrize("op", ["sum", "max"])
-@pytest.mark.parametrize("first_axis", [0, 1])
-def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathlib.Path):
-
-    op_str = ""
-    if op == "sum":
-        op_str = """
-        %13 = arith.addi %arg2, %arg3 : i32
-        tt.reduce.return %13 : i32"""
-    elif op == "max":
-        op_str = """
-        %13 = arith.cmpi "sgt", %arg2, %arg3 : i32
-        %14 = arith.select %13, %arg2, %arg3 : i32
-        tt.reduce.return %14 : i32"""
-    ir = f"""
-    #src = {src_layout}
-    module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-    tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
-        %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
-        %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
-        %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
-        %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
-        %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>>
-        %4 = tt.expand_dims %3 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
-        %5 = tt.broadcast %2 : tensor<{M}x1xi32, #src> -> tensor<{M}x{N}xi32, #src>
-        %6 = tt.broadcast %4 : tensor<1x{N}xi32, #src> -> tensor<{M}x{N}xi32, #src>
-        %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
-        %8 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
-        %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
-        %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr<i32>, #src>
-        %11 = "tt.reduce"(%10) ({{
-        ^bb0(%arg2: i32, %arg3: i32):
-        {op_str}
-        }}) {{axis = {first_axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>
-        %12 = "tt.reduce"(%11) ({{
-        ^bb0(%arg2: i32, %arg3: i32):
-        {op_str}
-        }}) {{axis = 0 : i32}} : (tensor<{M if first_axis == 1 else N}xi32, #{GPU_DIALECT}.slice<{{dim = {first_axis}, parent = #src}}>>) -> i32
-        tt.store %arg1, %12 : !tt.ptr<i32>
-        tt.return
-    }}
-    }}
-    """
-    temp_file = tmp_path / "test_chain_reduce.ttgir"
-    temp_file.write_text(ir)
-    kernel = triton.compile(str(temp_file))
-
-    rs = RandomState(17)
-    x = rs.randint(0, 4, (M, N)).astype('int32')
-
-    z = np.zeros((1, )).astype('int32')
-
-    x_tri = torch.tensor(x, device=device)
-    z_tri = torch.tensor(z, device=device)
-
-    pgm = kernel[(1, 1, 1)](x_tri, z_tri)
-    if op == "sum":
-        z_ref = np.sum(x)
-    elif op == "max":
-        z_ref = np.max(x)
-
-    np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
-
-
 @triton.jit
 def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
     delta = mean_2 - mean_1