Skip to content

Commit 9e6f975

Browse files
authored
[Coalesce]: Enhance the Intel coalescing pass to support while loops. (#4290)
Enhance the Intel GPU coalescing pass to handle `scf::WhileOp`.

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent dc4d128 commit 9e6f975

File tree

2 files changed

+232
-60
lines changed

2 files changed

+232
-60
lines changed

test/TritonIntelGPU/coalesce.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ module attributes {ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.n
472472
}
473473

474474
// -----
475+
475476
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [2, 4, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
476477
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [2, 1], order = [1, 0]}>
477478
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [8, 1, 4], warpsPerCTA = [2, 1, 1], order = [2, 1, 0]}>
@@ -522,3 +523,41 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.target_arch = "spir64", "tt
522523
tt.return
523524
}
524525
}
526+
527+
// -----
528+
529+
// COM: Ensure layout propagation works for a while loop.
530+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
531+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
532+
// CHECK: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
533+
// CHECK: kernel_make_tensor_descriptor_loop_carried
534+
tt.func public @kernel_make_tensor_descriptor_loop_carried(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i64 {tt.divisibility = 16 : i32}, %arg2: i64 {tt.divisibility = 16 : i32}) {
535+
%c1_i64 = arith.constant 1 : i64
536+
%c0_i32 = arith.constant 0 : i32
537+
%c2_i32 = arith.constant 2 : i32
538+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}} {order = array<i32: 1, 0>} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
539+
// CHECK: [[ADV_PTR:%.*]] = tt.advance [[PTR]], {{.*}} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
540+
%4 = tt.make_tensor_ptr %arg0, [%arg1, %arg2], [%arg2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #blocked>>
541+
%5 = tt.advance %4, [%c2_i32, %c0_i32] : <tensor<8x128xf32, #blocked>>
542+
%7 = arith.cmpi slt, %arg1, %arg2 : i64
543+
// CHECK: scf.while ([[ARG3:%.*]] = [[PTR]], [[ARG4:%.*]] = [[ADV_PTR]]) : (!tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>) -> (!tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>) {
544+
%6:2 = scf.while (%arg3 = %4, %arg4 = %5) : (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) -> (!tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>) {
545+
// CHECK: scf.condition({{.*}}) [[ARG3]], [[ARG4]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
546+
scf.condition(%7) %arg3, %arg4 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
547+
} do {
548+
// CHECK: ^bb0({{.*}}: !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, {{.*}}: !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>):
549+
^bb0(%arg3: !tt.ptr<tensor<8x128xf32, #blocked>>, %arg4: !tt.ptr<tensor<8x128xf32, #blocked>>):
550+
// CHECK: [[PTR1:%.*]] = arith.select {{.*}} : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
551+
// CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
552+
// CHECK: [[LOAD:%.*]] = tt.load [[PTR1]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
553+
// CHECK: tt.store [[PTR2]], {{.*}} : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
554+
// CHECK: scf.yield [[PTR1]], [[PTR2]] : !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>, !tt.ptr<tensor<8x128xf32, [[BLOCKED_LAYOUT]]>>
555+
%12 = arith.select %7, %arg4, %arg3 : !tt.ptr<tensor<8x128xf32, #blocked>>
556+
%13 = tt.advance %12, [%c0_i32, %c2_i32] : <tensor<8x128xf32, #blocked>>
557+
%15 = tt.load %12 : !tt.ptr<tensor<8x128xf32, #blocked>>
558+
tt.store %13, %15 : !tt.ptr<tensor<8x128xf32, #blocked>>
559+
scf.yield %12, %13 : !tt.ptr<tensor<8x128xf32, #blocked>>, !tt.ptr<tensor<8x128xf32, #blocked>>
560+
}
561+
tt.return
562+
}
563+
}

0 commit comments

Comments
 (0)