@@ -54,8 +54,7 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
5454// -----
5555
5656#blocked = #ttg.blocked <{sizePerThread = [1 ], threadsPerWarp = [32 ], warpsPerCTA = [4 ], order = [0 ]}>
57- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
58-
57+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
5958
6059// CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
6160// CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -343,7 +342,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
343342
344343#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 4 ], order = [1 , 0 ]}>
345344#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 1 , 32 ], warpsPerCTA = [1 , 4 , 4 ], order = [2 , 1 , 0 ]}>
346- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 32 : i32 } {
345+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 } {
347346 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
348347 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
349348 // CHECK: @triton_red_fused_mul_sum_0
@@ -412,7 +411,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
412411#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 16 ], order = [1 , 0 ]}>
413412#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [8 , 4 ], order = [1 , 0 ]}>
414413#blocked2 = #ttg.blocked <{sizePerThread = [4 , 4 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [8 , 4 ], order = [1 , 0 ]}>
415- module attributes {ttig.target_arch = " spir64 " , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , ttg.target = " xpu " , " ttg.threads-per-warp" = 16 : i32 } {
414+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , " ttg.threads-per-warp" = 16 : i32 } {
416415 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
417416 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
418417 // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
@@ -474,7 +473,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
474473#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [8 , 4 ], warpsPerCTA = [2 , 1 ], order = [1 , 0 ]}>
475474#blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [8 , 1 , 4 ], warpsPerCTA = [2 , 1 , 1 ], order = [2 , 1 , 0 ]}>
476475#blocked3 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 8 , 4 ], warpsPerCTA = [1 , 2 , 1 ], order = [0 , 1 , 2 ]}>
477- module attributes {ttig.min_sg_size = 16 : i32 , ttig.target_arch = " spir64 " , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 2 : i32 , ttg.target = " xpu " , " ttg.threads-per-warp " = 32 : i32 } {
476+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 2 : i32 } {
478477 // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
479478 // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
480479 // CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
@@ -587,3 +586,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
587586 tt.return
588587 }
589588}
589+
590+ // -----
591+
592+ // COM: Test layout propagation for nested operations (scf.if nested in scf.for).
593+ // COM: Reproducer for issue #4867
594+ #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
595+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
596+ // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
597+ // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
598+ // CHECK: @test_4867
599+ tt.func public @test_4867 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg2: i1 ) {
600+ %c0_i32 = arith.constant 0 : i32
601+ %c16_i32 = arith.constant 16 : i32
602+ %c128_i64 = arith.constant 128 : i64
603+ %c1_i64 = arith.constant 1 : i64
604+ %c32_i32 = arith.constant 32 : i32
605+ %0 = tt.make_tensor_ptr %arg0 , [%c128_i64 , %c128_i64 ], [%c1_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 0 , 1 >} : <tensor <128 x32 xf32 , #blocked >>
606+ %1 = tt.make_tensor_ptr %arg1 , [%c128_i64 , %c128_i64 ], [%c1_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x128 xf32 , #blocked >>
607+ %2:2 = scf.for %arg3 = %c0_i32 to %c32_i32 step %c32_i32 iter_args (%arg4 = %0 , %arg5 = %1 ) -> (!tt.ptr <tensor <128 x32 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>) : i32 {
608+ // CHECK: scf.for {{.*}}
609+ // CHECK-NOT: [[BLOCKED_LAYOUT]]>>
610+ %adv = tt.advance %arg5 , [%c32_i32 , %c0_i32 ] : <tensor <32 x128 xf32 , #blocked >>
611+ %3:2 = scf.if %arg2 -> (!tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>) {
612+ scf.yield %adv , %arg5 : !tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
613+ } else {
614+ scf.yield %arg5 , %adv : !tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
615+ }
616+ // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
617+ scf.yield %arg4 , %3#0 : !tt.ptr <tensor <128 x32 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
618+ }
619+ // CHECK: [[ADV:%.*]] = tt.advance {{.*}} : <tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>
620+ %3 = tt.advance %2#0 , [%c0_i32 , %c16_i32 ] : <tensor <128 x32 xf32 , #blocked >>
621+ // CHECK: [[LOAD:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
622+ %4 = tt.load %1 {boundaryCheck = array<i32 : 0 >, padding = 1 : i32 } : !tt.ptr <tensor <32 x128 xf32 , #blocked >>
623+ tt.return
624+ }
625+ }
0 commit comments