Skip to content

Commit d3c2c21

Browse files
committed
CR 2
1 parent 7bf573a commit d3c2c21

File tree

3 files changed

+45
-29
lines changed

3 files changed

+45
-29
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ def ConditionOp : SCF_Op<"condition", [
7777
//===----------------------------------------------------------------------===//
7878

7979
def ExecuteRegionOp : SCF_Op<"execute_region", [
80-
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
81-
RecursiveMemoryEffects]> {
80+
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
8281
let summary = "operation that executes its region exactly once";
8382
let description = [{
8483
The `scf.execute_region` operation is used to allow multiple blocks within SCF

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,10 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
292292
}
293293
};
294294

295-
// Pattern to eliminate ExecuteRegionOp results when it only forwards external
296-
// values. It operates only on execute regions with single terminator yield
297-
// operation.
295+
// Pattern to eliminate ExecuteRegionOp results which forward external
296+
// values from the region. In case there are multiple yield operations,
297+
// all of them must have the same operands iin order for the pattern to be
298+
// applicable.
298299
struct ExecuteRegionForwardingEliminator
299300
: public OpRewritePattern<ExecuteRegionOp> {
300301
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
@@ -306,11 +307,8 @@ struct ExecuteRegionForwardingEliminator
306307

307308
SmallVector<Operation *> yieldOps;
308309
for (Block &block : op.getRegion()) {
309-
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
310-
if (yield.getOperands().empty())
311-
continue;
310+
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
312311
yieldOps.push_back(yield.getOperation());
313-
}
314312
}
315313

316314
if (yieldOps.empty())
@@ -323,36 +321,52 @@ struct ExecuteRegionForwardingEliminator
323321
return failure();
324322
}
325323

326-
// Check if all yielded values are from outside the region
327-
bool allExternal = true;
328-
for (Value yieldedValue : yieldOpsOperands) {
324+
SmallVector<Value> externalValues;
325+
SmallVector<Value> internalValues;
326+
SmallVector<Value> opResultsToReplaceWithExternalValues;
327+
SmallVector<Value> opResultsToKeep;
328+
for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
329329
if (isValueFromInsideRegion(yieldedValue, op)) {
330-
allExternal = false;
331-
break;
330+
internalValues.push_back(yieldedValue);
331+
opResultsToKeep.push_back(op.getResult(index));
332+
} else {
333+
externalValues.push_back(yieldedValue);
334+
opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
332335
}
333336
}
334-
335-
if (!allExternal)
337+
// No yeilded external values - nothing to do.
338+
if (externalValues.empty())
336339
return failure();
337340

338-
// All yielded values are external - create a new execute_region with no
339-
// results.
340-
auto newOp = rewriter.create<ExecuteRegionOp>(op.getLoc(), TypeRange{});
341+
// There are yielded external values - create a new execute_region returning
342+
// just the internal values.
343+
SmallVector<Type> resultTypes;
344+
for (Value value : internalValues)
345+
resultTypes.push_back(value.getType());
346+
auto newOp =
347+
rewriter.create<ExecuteRegionOp>(op.getLoc(), TypeRange(resultTypes));
341348
newOp->setAttrs(op->getAttrs());
342349

343-
// Move the region content to the new operation
344-
newOp.getRegion().takeBody(op.getRegion());
350+
// Move old op's region to the new operation.
351+
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
352+
newOp.getRegion().end());
345353

346-
// Replace all yield operations with a new yield operation with no results.
347-
// scf.execute_region must have at least one yield operation.
354+
// Replace all yield operations with a new yield operation with updated
355+
// results. scf.execute_region must have at least one yield operation.
348356
for (auto *yieldOp : yieldOps) {
349357
rewriter.setInsertionPoint(yieldOp);
350358
rewriter.eraseOp(yieldOp);
351-
rewriter.create<scf::YieldOp>(yieldOp->getLoc());
359+
rewriter.create<scf::YieldOp>(yieldOp->getLoc(),
360+
ValueRange(internalValues));
352361
}
353362

354363
// Replace the old operation with the external values directly.
355-
rewriter.replaceOp(op, yieldOpsOperands);
364+
rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
365+
externalValues);
366+
// Replace the old operation's remaining results with the new operation's
367+
// results.
368+
rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
369+
rewriter.eraseOp(op);
356370
return success();
357371
}
358372

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,7 @@ func.func @func_execute_region_inline_multi_yield() {
16081608
// Make scf.execute_region not to return anything.
16091609

16101610
// CHECK: scf.execute_region no_inline {
1611+
// CHECK: func.call @foo() : () -> ()
16111612
// CHECK: scf.yield
16121613
// CHECK: }
16131614

@@ -1627,11 +1628,13 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
16271628

16281629
// Test case with scf.yield op inside execute_region with multiple operands.
16291630
// One of operands is defined outside the execute_region op.
1630-
// In this case scf.execute_region doesn't change.
1631-
1632-
// CHECK: %[[VAL_1:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1633-
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1631+
// Remove just this operand from the op results.
16341632

1633+
// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
1634+
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1635+
// CHECK: func.call @foo() : () -> ()
1636+
// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8>
1637+
// CHECK: }
16351638
module {
16361639
func.func private @foo()->()
16371640
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {

0 commit comments

Comments
 (0)