@@ -472,6 +472,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
472
472
}
473
473
474
474
// -----
475
+
475
476
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [2 , 4 , 4 ], warpsPerCTA = [2 , 1 , 1 ], order = [2 , 1 , 0 ]}>
476
477
#blocked1 = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [8 , 4 ], warpsPerCTA = [2 , 1 ], order = [1 , 0 ]}>
477
478
#blocked2 = #ttg.blocked <{sizePerThread = [1 , 1 , 1 ], threadsPerWarp = [8 , 1 , 4 ], warpsPerCTA = [2 , 1 , 1 ], order = [2 , 1 , 0 ]}>
@@ -522,3 +523,41 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.target_arch = "spir64", "tt
522
523
tt.return
523
524
}
524
525
}
526
+
527
+ // -----
528
+
529
+ // COM: Ensure layout propagation works for a while loop: the layout captured as BLOCKED_LAYOUT must appear on the block pointers, the scf.while operands and results, the scf.condition operands, the "do" region arguments, and every use inside the loop body up to the scf.yield.
530
+ #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
531
+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
532
+ // CHECK: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
533
+ // CHECK: kernel_make_tensor_descriptor_loop_carried
534
+ tt.func public @kernel_make_tensor_descriptor_loop_carried (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 }, %arg1: i64 {tt.divisibility = 16 : i32 }, %arg2: i64 {tt.divisibility = 16 : i32 }) {
535
+ %c1_i64 = arith.constant 1 : i64
536
+ %c0_i32 = arith.constant 0 : i32
537
+ %c2_i32 = arith.constant 2 : i32
538
+ // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} {order = array<i32: 1, 0>} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
539
+ // CHECK: [[ADV_PTR:%.*]] = tt.advance [[PTR]], {{.*}} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
540
+ %4 = tt.make_tensor_ptr %arg0 , [%arg1 , %arg2 ], [%arg2 , %c1_i64 ], [%c0_i32 , %c0_i32 ] {order = array<i32 : 1 , 0 >} : <tensor <8 x128 xf32 , #blocked >>
541
+ %5 = tt.advance %4 , [%c2_i32 , %c0_i32 ] : <tensor <8 x128 xf32 , #blocked >>
542
+ %7 = arith.cmpi slt , %arg1 , %arg2 : i64
543
+ // CHECK: scf.while ([[ARG3:%.*]] = [[PTR]], [[ARG4:%.*]] = [[ADV_PTR]]) : (!tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>) -> (!tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>) {
544
+ %6:2 = scf.while (%arg3 = %4 , %arg4 = %5 ) : (!tt.ptr <tensor <8 x128 xf32 , #blocked >>, !tt.ptr <tensor <8 x128 xf32 , #blocked >>) -> (!tt.ptr <tensor <8 x128 xf32 , #blocked >>, !tt.ptr <tensor <8 x128 xf32 , #blocked >>) {
545
+ // CHECK: scf.condition({{.*}}) [[ARG3]], [[ARG4]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
546
+ scf.condition (%7 ) %arg3 , %arg4 : !tt.ptr <tensor <8 x128 xf32 , #blocked >>, !tt.ptr <tensor <8 x128 xf32 , #blocked >>
547
+ } do {
548
+ // CHECK: ^bb0({{.*}}: !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, {{.*}}: !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>):
549
+ ^bb0 (%arg3: !tt.ptr <tensor <8 x128 xf32 , #blocked >>, %arg4: !tt.ptr <tensor <8 x128 xf32 , #blocked >>):
550
+ // CHECK: [[PTR1:%.*]] = arith.select {{.*}} : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
551
+ // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
552
+ // CHECK: [[LOAD:%.*]] = tt.load [[PTR1]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
553
+ // CHECK: tt.store [[PTR2]], {{.*}} : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
554
+ // CHECK: scf.yield [[PTR1]], [[PTR2]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
555
+ %12 = arith.select %7 , %arg4 , %arg3 : !tt.ptr <tensor <8 x128 xf32 , #blocked >>
556
+ %13 = tt.advance %12 , [%c0_i32 , %c2_i32 ] : <tensor <8 x128 xf32 , #blocked >>
557
+ %15 = tt.load %12 : !tt.ptr <tensor <8 x128 xf32 , #blocked >>
558
+ tt.store %13 , %15 : !tt.ptr <tensor <8 x128 xf32 , #blocked >>
559
+ scf.yield %12 , %13 : !tt.ptr <tensor <8 x128 xf32 , #blocked >>, !tt.ptr <tensor <8 x128 xf32 , #blocked >>
560
+ }
561
+ tt.return
562
+ }
563
+ }
0 commit comments