
Commit 26e6e97

Add support for dynamic unit trip scf.for to scf.if (#20880)

This PR adds support for dynamic unit trip (0 or 1 trip) scf.for using scf.if.

Signed-off-by: Nirvedh Meshram <[email protected]>

1 parent 48b081d
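
For orientation, a hedged sketch of the rewrite this commit enables (illustrative IR under assumed bounds; the op "some.op" and the value names are hypothetical, not taken from the test suite). A loop whose trip count is provably at most one, but whose first iteration is not guaranteed, becomes an scf.if guarded by a comparison of the bounds:

    // Before: range analysis proves 0 <= %ub <= 3, so with step %c3 the loop
    // runs zero times or once; it cannot simply be erased.
    scf.for %i = %c0 to %ub step %c3 {
      "some.op"(%i) : (index) -> ()
    }

    // After: the body is inlined into an scf.if, with the lower bound
    // substituted for the induction variable.
    %cond = arith.cmpi sgt, %ub, %c0 : index
    scf.if %cond {
      "some.op"(%c0) : (index) -> ()
    }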

File tree

3 files changed: +134 −16 lines

compiler/src/iree/compiler/Codegen/Common/test/remove_single_iteration_loop.mlir

Lines changed: 98 additions & 5 deletions
@@ -16,11 +16,14 @@ func.func @thread_tile_loop() {
       gpu.barrier
     }
   }
-  // The inner loop doesn't always execute once so it cannot be removed.
-  // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C250]] step %[[C250]]
-  // CHECK: gpu.barrier
   scf.for %arg3 = %tidy to %c2 step %c2 {
+    // CHECK-NOT: scf.for
     %0 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%tidx]
+    // CHECK: %[[LB:.+]] = affine.apply
+    // The inner loop doesn't always execute once so it needs an scf.if
+    // CHECK: %[[COND:.+]] = arith.cmpi slt, %[[LB]], %[[C250]] : index
+    // CHECK: scf.if %[[COND]] {
+    // CHECK: gpu.barrier
     scf.for %arg4 = %0 to %c250 step %c250 {
       gpu.barrier
     }
@@ -161,6 +164,7 @@ func.func @delinearize_linearize() {
   %c64 = arith.constant 64 : index
   %tidx = gpu.thread_id x upper_bound 128
   %ids:2 = affine.delinearize_index %tidx into (4, 32) : index, index
+  // CHECK: %[[IDS:.+]]:2 = affine.delinearize_index
   // CHECK-NOT: scf.for
   // CHECK: gpu.barrier
   scf.for %arg3 = %ids#0 to %c4 step %c4 {
@@ -169,8 +173,9 @@ func.func @delinearize_linearize() {
       gpu.barrier
     }
   }
-  // The loop loop doesn't always execute once so it cannot be removed.
-  // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C3]] step %{{.*}}
+  // The loop doesn't always execute once so it needs an scf.if
+  // CHECK: %[[COND:.+]] = arith.cmpi slt, %[[IDS:.+]]#0, %[[C3]] : index
+  // CHECK: scf.if %[[COND]] {
   // CHECK: gpu.barrier
   scf.for %arg3 = %ids#0 to %c3 step %c4 {
     gpu.barrier
@@ -220,3 +225,91 @@ func.func @argument_with_assume(%arg_index : index) {
   }
   return
 }
+
+// -----
+
+func.func @dynamic_ub_unittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c1 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 3> : index
+  scf.for %arg1 = %c1 to %0 step %c3 {
+    %alloc = memref.alloc() : memref<4xf16>
+    %subview = memref.subview %arg_value[%arg1][4][1] : memref<8xf16> to memref<4xf16, strided<[1], offset: ?>>
+    memref.copy %alloc, %subview : memref<4xf16> to memref<4xf16, strided<[1], offset: ?>>
+  }
+  return
+}
+// CHECK-LABEL: func.func @dynamic_ub_unittrip
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGVALUE:.+]]: memref<8xf16>)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[UB:.+]] = util.assume.int %[[ARGINDEX]]
+// CHECK: %[[COND:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]] : index
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.alloc()
+// CHECK: memref.subview %[[ARGVALUE]][%[[C0]]] [4] [1]
+// CHECK: memref.copy
+
+// -----
+
+func.func @dynamic_lb_unittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 50> : index
+  scf.for %arg1 = %0 to %c3 step %c3 {
+    %alloc = memref.alloc() : memref<4xf16>
+    %subview = memref.subview %arg_value[%arg1][4][1] : memref<8xf16> to memref<4xf16, strided<[1], offset: ?>>
+    memref.copy %alloc, %subview : memref<4xf16> to memref<4xf16, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @dynamic_lb_unittrip
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGVALUE:.+]]: memref<8xf16>)
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[LB:.+]] = util.assume.int %[[ARGINDEX]]
+// CHECK: %[[COND:.+]] = arith.cmpi slt, %[[LB]], %[[C3]] : index
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.alloc()
+// CHECK: memref.subview %[[ARGVALUE]][%[[LB]]] [4] [1]
+// CHECK: memref.copy
+
+// -----
+
+func.func @dynamic_nonunittrip(%arg_index : index, %arg_value : memref<8xf16>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 5> : index
+  scf.for %arg1 = %c1 to %0 step %c3 {
+    gpu.barrier
+  }
+  return
+}
+// CHECK-LABEL: func.func @dynamic_nonunittrip
+// CHECK: scf.for
+
+// -----
+
+func.func @dynamic_unittrip_with_destination(%arg_index : index, %arg_value : tensor<8xf16>) -> tensor<4xf16> {
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %0 = util.assume.int %arg_index<umin = 0, umax = 3> : index
+  %empty = tensor.empty() : tensor<4xf16>
+  %1 = scf.for %arg1 = %c0 to %0 step %c3 iter_args(%arg2 = %empty) -> (tensor<4xf16>) {
+    %extract = tensor.extract_slice %arg_value[%arg1][4][1] : tensor<8xf16> to tensor<4xf16>
+    %2 = arith.negf %extract : tensor<4xf16>
+    scf.yield %2 : tensor<4xf16>
+  }
+  return %1 : tensor<4xf16>
+}
+
+// CHECK-LABEL: func.func @dynamic_unittrip_with_destination
+// CHECK-SAME: (%[[ARGINDEX:.+]]: index, %[[ARGTENSOR:.+]]: tensor<8xf16>)
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4xf16>
+// CHECK: %[[RESULT:.+]] = scf.if
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK: %[[NEG:.+]] = arith.negf %[[SLICE]] : tensor<4xf16>
+// CHECK: scf.yield %[[NEG]] : tensor<4xf16>
+// CHECK: } else {
+// CHECK: scf.yield %[[EMPTY]] : tensor<4xf16>
+// CHECK: }
+// CHECK: return %[[RESULT]] : tensor<4xf16>
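
A note on why @dynamic_nonunittrip keeps its scf.for: the trip count of an scf.for is ceildiv(ub - lb, step), clamped at zero, and the pattern fires only when that count is provably at most one. A worked instance under the bounds assumed in the test:

    lb = 1, step = 3, ub in [0, 5]
    ub = 5  =>  ceildiv(5 - 1, 3) = 2 trips  (a second iteration is possible)
    ub = 3  =>  ceildiv(3 - 1, 3) = 1 trip

Since some feasible upper bounds give two iterations, the analysis cannot prove a unit trip and the loop is left intact.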

compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp

Lines changed: 1 addition & 0 deletions
@@ -403,6 +403,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(

   // Tile and distribute to GPU subgroups.
   funcPassManager.addPass(createSPIRVTileToCooperativeOpsPass());
+  funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
   funcPassManager.addPass(createRemoveSingleIterationLoopPass());
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
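
Pass ordering is the point of this change: the single-iteration-loop pass reasons about loop bounds using value-range information, so the bounds-propagation pass has to run after tiling and before it. A hedged illustration of the kind of IR this enables (the upper_bound form appears in the test file above; the exact IR produced in this pipeline may differ):

    // After bounds propagation, ID ops carry range facts such as:
    %tidx = gpu.thread_id x upper_bound 128
    // ...which lets the next pass prove that a distribution loop like
    //   scf.for %i = %tidx to %c128 step %c128 { ... }
    // executes at most once, so it can be inlined or guarded by an scf.if.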

compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp

Lines changed: 35 additions & 11 deletions
@@ -27,17 +27,39 @@ namespace mlir::iree_compiler {

 /// Replaces the given op with the contents of the given single-block region,
 /// using the operands of the block terminator to replace operation results.
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
-                                Region &region, ValueRange blockArgs = {}) {
-  assert(llvm::hasSingleElement(region) && "expected single-region block");
-  Block *block = &region.front();
+static void replaceOpWithRegion(PatternRewriter &rewriter, scf::ForOp op,
+                                ValueRange blockArgs = {}) {
+  Block *block = op.getBody();
   Operation *terminator = block->getTerminator();
   ValueRange results = terminator->getOperands();
   rewriter.inlineBlockBefore(block, op, blockArgs);
   rewriter.replaceOp(op, results);
   rewriter.eraseOp(terminator);
 }

+/// Same as the `replaceOpWithRegion` function, but inlines the loop body into
+/// an scf.if region.
+static void replaceForWithIf(PatternRewriter &rewriter, scf::ForOp op,
+                             ValueRange blockArgs = {}) {
+  Block *block = op.getBody();
+  ValueRange initArgs = op.getInitArgs();
+  Value count =
+      rewriter.create<arith::CmpIOp>(op->getLoc(), arith::CmpIPredicate::sgt,
+                                     op.getUpperBound(), op.getLowerBound());
+  auto ifOp =
+      rewriter.create<scf::IfOp>(op->getLoc(), op.getResultTypes(), count,
+                                 /*withElseRegion=*/initArgs.size() != 0);
+  Operation *terminator = block->getTerminator();
+  rewriter.inlineBlockBefore(block, &ifOp.getThenRegion().front(),
+                             ifOp.getThenRegion().front().begin(), blockArgs);
+  if (initArgs.size() == 0) {
+    rewriter.eraseOp(terminator);
+  } else {
+    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    rewriter.create<scf::YieldOp>(ifOp.getLoc(), initArgs);
+  }
+  rewriter.replaceOp(op, ifOp);
+}
+
 /// Return true if we can prove that the we always run at least the first
 /// iteration of the ForOp.
 static bool alwaysRunsFirstIteration(scf::ForOp op) {
@@ -75,20 +97,22 @@ struct SimplifyTrivialLoops : public OpRewritePattern<scf::ForOp> {

   LogicalResult matchAndRewrite(scf::ForOp op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Handle the case where we know that the loop doesn't run more than
-    // once but the loop may not run at least once by replace the `loop` with an
-    // `if`.
-    if (!(alwaysRunsFirstIteration(op) && neverRunsSecondIteration(op))) {
+    if (!(neverRunsSecondIteration(op))) {
       return failure();
     }

-    // The first iteration is always run and the second iteration is never run
-    // so the loop always have 1 iteration. Inline its body and remove the loop.
+    // The second iteration is never run, so the loop has at most one
+    // iteration. Inline its body and remove the loop.
     SmallVector<Value> blockArgs;
     blockArgs.reserve(op.getInitArgs().size() + 1);
     blockArgs.push_back(op.getLowerBound());
     llvm::append_range(blockArgs, op.getInitArgs());
-    replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
+    if (alwaysRunsFirstIteration(op)) {
+      replaceOpWithRegion(rewriter, op, blockArgs);
+    } else {
+      replaceForWithIf(rewriter, op, blockArgs);
+    }
     return success();
   }
 };
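
To see why replaceForWithIf materializes an else region only when the loop has init args: an scf.if that produces results must yield from both branches, and when the body never runs, the loop's results are exactly its init values. A hedged sketch of the loop-carried case (illustrative IR with hypothetical values %x and %init, not taken from the tests):

    // Before: a 0-or-1-trip loop carrying %init.
    %r = scf.for %i = %lb to %ub step %s iter_args(%acc = %init) -> (f32) {
      %v = arith.addf %acc, %x : f32
      scf.yield %v : f32
    }

    // After: the inlined body's scf.yield becomes the then-branch yield, the
    // induction variable is replaced by %lb, %acc by %init, and the generated
    // else branch re-yields the untouched init args.
    %cond = arith.cmpi sgt, %ub, %lb : index
    %r = scf.if %cond -> (f32) {
      %v = arith.addf %init, %x : f32
      scf.yield %v : f32
    } else {
      scf.yield %init : f32
    }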
