Skip to content

Commit 9b98e58

Browse files
committed
Adding to execute_region_op some missing support
1 parent 4ad625b commit 9b98e58

File tree

3 files changed

+110
-2
lines changed

3 files changed

+110
-2
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def ConditionOp : SCF_Op<"condition", [
7777
//===----------------------------------------------------------------------===//
7878

7979
def ExecuteRegionOp : SCF_Op<"execute_region", [
80-
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
80+
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
81+
RecursiveMemoryEffects,
82+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {//, SingleBlockImplicitTerminator<"scf::YieldOp">]> { //, RecursiveMemoryEffects]> {
8183
let summary = "operation that executes its region exactly once";
8284
let description = [{
8385
The `scf.execute_region` operation is used to allow multiple blocks within SCF

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,93 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291
}
292292
};
293293

294+
// Pattern to eliminate ExecuteRegionOp results when it only forwards external values.
295+
// It operates only on execute regions with single terminator yield operation.
296+
struct ExecuteRegionForwardingEliminator
297+
: public OpRewritePattern<ExecuteRegionOp> {
298+
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
299+
300+
LogicalResult matchAndRewrite(ExecuteRegionOp op,
301+
PatternRewriter &rewriter) const override {
302+
if (op.getNumResults() == 0)
303+
return failure();
304+
305+
SmallVector<Operation *> yieldOps;
306+
for (Block &block : op.getRegion()) {
307+
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
308+
if (yield.getResults().empty())
309+
continue;
310+
yieldOps.push_back(yield.getOperation());
311+
}
312+
}
313+
314+
if (yieldOps.size() != 1)
315+
return failure();
316+
317+
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front());
318+
auto yieldedValues = yieldOp.getOperands();
319+
// Check if all yielded values are from outside the region
320+
bool allExternal = true;
321+
for (Value yieldedValue : yieldedValues) {
322+
if (isValueFromInsideRegion(yieldedValue, op)) {
323+
allExternal = false;
324+
break;
325+
}
326+
}
327+
328+
if (!allExternal)
329+
return failure();
330+
331+
// All yielded values are external - create a new execute_region with no
332+
// results.
333+
auto newOp = rewriter.create<ExecuteRegionOp>(op.getLoc(), TypeRange{});
334+
newOp->setAttrs(op->getAttrs());
335+
336+
// Move the region content to the new operation
337+
newOp.getRegion().takeBody(op.getRegion());
338+
339+
// Replace the yield operation with a new yield operation with no results.
340+
rewriter.setInsertionPoint(yieldOp);
341+
rewriter.eraseOp(yieldOp);
342+
rewriter.create<scf::YieldOp>(yieldOp.getLoc());
343+
344+
// Replace the old operation with the external values directly.
345+
rewriter.replaceOp(op, yieldedValues);
346+
return success();
347+
}
348+
349+
private:
350+
bool isValueFromInsideRegion(Value value,
351+
ExecuteRegionOp executeRegionOp) const {
352+
// Check if the value is defined within the execute_region
353+
if (Operation *defOp = value.getDefiningOp()) {
354+
return executeRegionOp.getRegion().isAncestor(defOp->getParentRegion());
355+
}
356+
357+
// If it's a block argument, check if it's from within the region
358+
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
359+
return executeRegionOp.getRegion().isAncestor(blockArg.getParentRegion());
360+
}
361+
362+
return false; // Value is from outside the region
363+
}
364+
};
365+
294366
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
295367
MLIRContext *context) {
296-
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
368+
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, ExecuteRegionForwardingEliminator>(
369+
context);
370+
}
371+
372+
void ExecuteRegionOp::getEffects(
373+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
374+
&effects) {
375+
if (!getNoInline())
376+
return;
377+
// In case there is attribute no_inline we want the region not to be inlined
378+
// into the parent operation.
379+
effects.emplace_back(MemoryEffects::Write::get(),
380+
SideEffects::DefaultResource::get());
297381
}
298382

299383
void ExecuteRegionOp::getSuccessorRegions(

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,28 @@ func.func @func_execute_region_inline_multi_yield() {
16041604

16051605
// -----
16061606

1607+
// CHECK-LABEL: func.func private @canonicalize_execute_region_yeilding_external_value(
1608+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x120xui8>) -> tensor<1x60xui8> {
1609+
// CHECK: %[[VAL_1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1610+
// CHECK: scf.execute_region no_inline {
1611+
// CHECK: scf.yield
1612+
// CHECK: }
1613+
// CHECK: %[[VAL_2:.*]] = bufferization.to_tensor %[[VAL_1]] : memref<1x60xui8> to tensor<1x60xui8>
1614+
// CHECK: return %[[VAL_2]] : tensor<1x60xui8>
1615+
// CHECK: }
1616+
1617+
func.func private @canonicalize_execute_region_yeilding_external_value(%arg0: tensor<1x120xui8>) -> tensor<1x60xui8> {
1618+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1619+
%0 = bufferization.to_tensor %alloc : memref<1x60xui8> to tensor<1x60xui8>
1620+
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1621+
scf.yield %alloc : memref<1x60xui8>
1622+
}
1623+
%2 = bufferization.to_tensor %1 : memref<1x60xui8> to tensor<1x60xui8>
1624+
return %2 : tensor<1x60xui8>
1625+
}
1626+
1627+
// -----
1628+
16071629
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
16081630
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
16091631
func.func @canonicalize_parallel_insert_slice_indices(

0 commit comments

Comments
 (0)