@@ -54,8 +54,7 @@ tt.func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
54
54
// -----
55
55
56
56
#blocked = #ttg.blocked <{sizePerThread = [1 ], threadsPerWarp = [32 ], warpsPerCTA = [4 ], order = [0 ]}>
57
- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
58
-
57
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
59
58
60
59
// CHECK: [[NARROW_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
61
60
// CHECK: [[WIDE_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -343,7 +342,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
343
342
344
343
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 4 ], order = [1 , 0 ]}>
345
344
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 1 , 32 ], warpsPerCTA = [1 , 4 , 4 ], order = [2 , 1 , 0 ]}>
346
- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 32 : i32 } {
345
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 } {
347
346
// CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
348
347
// CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
349
348
// CHECK: @triton_red_fused_mul_sum_0
@@ -412,7 +411,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
412
411
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [2 , 16 ], order = [1 , 0 ]}>
413
412
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [8 , 4 ], order = [1 , 0 ]}>
414
413
#blocked2 = #ttg.blocked <{sizePerThread = [4 , 4 ], threadsPerWarp = [1 , 16 ], warpsPerCTA = [8 , 4 ], order = [1 , 0 ]}>
415
- module attributes {ttig.target_arch = " spir64 " , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , ttg.target = " xpu " , " ttg.threads-per-warp" = 16 : i32 } {
414
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , " ttg.threads-per-warp" = 16 : i32 } {
416
415
// CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 16], order = [1, 0]}>
417
416
// CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
418
417
// CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
@@ -474,7 +473,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
474
473
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [8 , 4 ], warpsPerCTA = [2 , 1 ], order = [1 , 0 ]}>
475
474
#blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [8 , 1 , 4 ], warpsPerCTA = [2 , 1 , 1 ], order = [2 , 1 , 0 ]}>
476
475
#blocked3 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [1 , 8 , 4 ], warpsPerCTA = [1 , 2 , 1 ], order = [0 , 1 , 2 ]}>
477
- module attributes {ttig.min_sg_size = 16 : i32 , ttig.target_arch = " spir64 " , " ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 2 : i32 , ttg.target = " xpu " , " ttg.threads-per-warp " = 32 : i32 } {
476
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 2 : i32 } {
478
477
// CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
479
478
// CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
480
479
// CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
@@ -587,3 +586,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
587
586
tt.return
588
587
}
589
588
}
589
+
590
+ // -----
591
+
592
+ // COM: Test layout propagation for nested operations (scf.if nested in scf.for).
593
+ // COM: Reproducer for issue #4867
594
+ #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
595
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
596
+ // CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
597
+ // CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
598
+ // CHECK: @test_4867
599
+ tt.func public @test_4867 (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg1: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg2: i1 ) {
600
+ %c0_i32 = arith.constant 0 : i32
601
+ %c16_i32 = arith.constant 16 : i32
602
+ %c128_i64 = arith.constant 128 : i64
603
+ %c1_i64 = arith.constant 1 : i64
604
+ %c32_i32 = arith.constant 32 : i32
605
+ %0 = tt.make_tensor_ptr %arg0 , [%c128_i64 , %c128_i64 ], [%c1_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 0 , 1 >} : <tensor <128 x32 xf32 , #blocked >>
606
+ %1 = tt.make_tensor_ptr %arg1 , [%c128_i64 , %c128_i64 ], [%c1_i64 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <32 x128 xf32 , #blocked >>
607
+ %2:2 = scf.for %arg3 = %c0_i32 to %c32_i32 step %c32_i32 iter_args (%arg4 = %0 , %arg5 = %1 ) -> (!tt.ptr <tensor <128 x32 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>) : i32 {
608
+ // CHECK: scf.for {{.*}}
609
+ // CHECK-NOT: [[BLOCKED_LAYOUT]]>>
610
+ %adv = tt.advance %arg5 , [%c32_i32 , %c0_i32 ] : <tensor <32 x128 xf32 , #blocked >>
611
+ %3:2 = scf.if %arg2 -> (!tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>) {
612
+ scf.yield %adv , %arg5 : !tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
613
+ } else {
614
+ scf.yield %arg5 , %adv : !tt.ptr <tensor <32 x128 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
615
+ }
616
+ // CHECK: scf.yield {{.*}} : !tt.ptr<tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
617
+ scf.yield %arg4 , %3#0 : !tt.ptr <tensor <128 x32 xf32 , #blocked >>, !tt.ptr <tensor <32 x128 xf32 , #blocked >>
618
+ }
619
+ // CHECK: [[ADV:%.*]] = tt.advance {{.*}} : <tensor<128x32xf32, [[BLOCKED_LAYOUT]]>>
620
+ %3 = tt.advance %2#0 , [%c0_i32 , %c16_i32 ] : <tensor <128 x32 xf32 , #blocked >>
621
+ // CHECK: [[LOAD:%.*]] = tt.load {{.*}} : !tt.ptr<tensor<32x128xf32, [[BLOCKED_LAYOUT1]]>>
622
+ %4 = tt.load %1 {boundaryCheck = array<i32 : 0 >, padding = 1 : i32 } : !tt.ptr <tensor <32 x128 xf32 , #blocked >>
623
+ tt.return
624
+ }
625
+ }
0 commit comments