@@ -345,8 +345,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
345
345
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 1 , 32 ], warpsPerCTA = [1 , 4 , 4 ], order = [2 , 1 , 0 ]}>
346
346
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
347
347
// CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
348
- // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 1, 16], order = [0, 1, 2]}>
349
- // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
348
+ // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
350
349
// CHECK: @triton_red_fused_mul_sum_0
351
350
tt.func public @triton_red_fused_mul_sum_0 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }) {
352
351
%c128_i32 = arith.constant 128 : i32
@@ -368,7 +367,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
368
367
// CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]], [[ARG2:%.*]] = {{.*}}) -> (!tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>)
369
368
%8:2 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args (%arg6 = %6 , %arg8 = %cst_0 ) -> (!tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>, tensor <32 x128 xf32 , #blocked >) : i32 {
370
369
// CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
371
- // CHECK-NEXT: ttg.convert_layout [[LOAD]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
372
370
%17 = tt.load %arg6 evictionPolicy = evict_last {boundaryCheck = array<i32 : 2 >, padding = 1 : i32 } : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>
373
371
// CHECK: scf.yield [[ARG1]], [[ARG2]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>
374
372
scf.yield %arg6 , %arg8 : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>, tensor <32 x128 xf32 , #blocked >
@@ -404,7 +402,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
404
402
scf.yield %arg7 : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>
405
403
}
406
404
// CHECK: [[LOAD_RES:%.*]] = tt.load [[RES]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
407
- // CHECK: ttg.convert_layout [[LOAD_RES]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
408
405
%res = tt.load %8#0 : !tt.ptr <tensor <1 x32 x128 xf32 , #blocked1 >>
409
406
tt.return
410
407
}
0 commit comments