@@ -336,3 +336,50 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
336336 tt.return
337337 }
338338}
339+
// -----

// COM: Test coalescing on blocked pointers: loop result used by tt.reduce
// COM: The loop carries a block pointer and a 2-D accumulator tensor; only the
// COM: accumulator (result #1 of the loop) is consumed by the reduction that
// COM: follows, so its layout is expected to stay unchanged while the pointer
// COM: and the load pick up the new 3-D layout.

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
  // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 1, 16], order = [0, 1, 2]}>
  // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
  // CHECK: @triton_red_fused_mul_sum_0
  tt.func public @triton_red_fused_mul_sum_0(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
    %c128_i32 = arith.constant 128 : i32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf32, #blocked>
    %c0_i32 = arith.constant 0 : i32
    %c262144_i64 = arith.constant 262144 : i64
    %c1_i64 = arith.constant 1 : i64
    %c512_i64 = arith.constant 512 : i64
    %c32_i32 = arith.constant 32 : i32
    %c512_i32 = arith.constant 512 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %3 = tt.expand_dims %2 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
    %4 = arith.divsi %1, %c512_i32 : i32
    %5 = arith.remsi %1, %c512_i32 : i32
    // COM: The block pointer is expected to be rewritten to the layout captured as BLOCKED_LAYOUT1.
    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
    %6 = tt.make_tensor_ptr %arg0, [%c512_i64, %c512_i64, %c512_i64], [%c1_i64, %c512_i64, %c262144_i64], [%4, %5, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x128xf32, #blocked1>>
    // COM: Loop-carried values: the block pointer (new layout) and the accumulator (original layout).
    // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]], [[ARG2:%.*]] = {{.*}}) -> (!tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>)
    %8:2 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args(%arg6 = %6, %arg8 = %cst_0) -> (!tt.ptr<tensor<1x32x128xf32, #blocked1>>, tensor<32x128xf32, #blocked>)  : i32 {
      // COM: The load must use the new layout and be followed immediately by a
      // COM: conversion back to the layout the load originally produced.
      // CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
      // CHECK-NEXT: triton_gpu.convert_layout [[LOAD]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
      %17 = tt.load %arg6 evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
      // COM: Both iteration arguments are yielded unchanged.
      // CHECK: scf.yield [[ARG1]], [[ARG2]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>
      scf.yield %arg6, %arg8 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>, tensor<32x128xf32, #blocked>
    }
    // COM: The reduction consumes loop result #1 and must keep the original 2-D layout
    // COM: on both its operand and its sliced result.
    // CHECK: = "tt.reduce"([[RES]]#1) <{axis = 1 : i32}> ({
    // CHECK: }) : (tensor<32x128xf32, [[BLOCKED_LAYOUT]]>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = [[BLOCKED_LAYOUT]]}>>
    %9 = "tt.reduce"(%8#1) <{axis = 1 : i32}> ({
    ^bb0(%arg5: f32, %arg6: f32):
      %14 = arith.addf %arg5, %arg6 : f32
      tt.reduce.return %14 : f32
    }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    tt.return
  }
}
0 commit comments