Skip to content

Commit e793c80

Browse files
committed
Adding to execute_region_op some missing support
1 parent 4ad625b commit e793c80

File tree

3 files changed

+111
-2
lines changed

3 files changed

+111
-2
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def ConditionOp : SCF_Op<"condition", [
7777
//===----------------------------------------------------------------------===//
7878

7979
def ExecuteRegionOp : SCF_Op<"execute_region", [
80-
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
80+
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
81+
RecursiveMemoryEffects,
82+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {//, SingleBlockImplicitTerminator<"scf::YieldOp">]> { //, RecursiveMemoryEffects]> {
8183
let summary = "operation that executes its region exactly once";
8284
let description = [{
8385
The `scf.execute_region` operation is used to allow multiple blocks within SCF

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,94 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291
}
292292
};
293293

294+
// Pattern to eliminate ExecuteRegionOp results when it only forwards external
295+
// values. It operates only on execute regions with single terminator yield
296+
// operation.
297+
struct ExecuteRegionForwardingEliminator
298+
: public OpRewritePattern<ExecuteRegionOp> {
299+
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
300+
301+
LogicalResult matchAndRewrite(ExecuteRegionOp op,
302+
PatternRewriter &rewriter) const override {
303+
if (op.getNumResults() == 0)
304+
return failure();
305+
306+
SmallVector<Operation *> yieldOps;
307+
for (Block &block : op.getRegion()) {
308+
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
309+
if (yield.getResults().empty())
310+
continue;
311+
yieldOps.push_back(yield.getOperation());
312+
}
313+
}
314+
315+
if (yieldOps.size() != 1)
316+
return failure();
317+
318+
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front());
319+
auto yieldedValues = yieldOp.getOperands();
320+
// Check if all yielded values are from outside the region
321+
bool allExternal = true;
322+
for (Value yieldedValue : yieldedValues) {
323+
if (isValueFromInsideRegion(yieldedValue, op)) {
324+
allExternal = false;
325+
break;
326+
}
327+
}
328+
329+
if (!allExternal)
330+
return failure();
331+
332+
// All yielded values are external - create a new execute_region with no
333+
// results.
334+
auto newOp = rewriter.create<ExecuteRegionOp>(op.getLoc(), TypeRange{});
335+
newOp->setAttrs(op->getAttrs());
336+
337+
// Move the region content to the new operation
338+
newOp.getRegion().takeBody(op.getRegion());
339+
340+
// Replace the yield operation with a new yield operation with no results.
341+
rewriter.setInsertionPoint(yieldOp);
342+
rewriter.eraseOp(yieldOp);
343+
rewriter.create<scf::YieldOp>(yieldOp.getLoc());
344+
345+
// Replace the old operation with the external values directly.
346+
rewriter.replaceOp(op, yieldedValues);
347+
return success();
348+
}
349+
350+
private:
351+
bool isValueFromInsideRegion(Value value,
352+
ExecuteRegionOp executeRegionOp) const {
353+
// Check if the value is defined within the execute_region
354+
if (Operation *defOp = value.getDefiningOp()) {
355+
return executeRegionOp.getRegion().isAncestor(defOp->getParentRegion());
356+
}
357+
358+
// If it's a block argument, check if it's from within the region
359+
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
360+
return executeRegionOp.getRegion().isAncestor(blockArg.getParentRegion());
361+
}
362+
363+
return false; // Value is from outside the region
364+
}
365+
};
366+
294367
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
295368
MLIRContext *context) {
296-
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
369+
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
370+
ExecuteRegionForwardingEliminator>(context);
371+
}
372+
373+
void ExecuteRegionOp::getEffects(
374+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
375+
&effects) {
376+
if (!getNoInline())
377+
return;
378+
// In case there is attribute no_inline we want the region not to be inlined
379+
// into the parent operation.
380+
effects.emplace_back(MemoryEffects::Write::get(),
381+
SideEffects::DefaultResource::get());
297382
}
298383

299384
void ExecuteRegionOp::getSuccessorRegions(

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,28 @@ func.func @func_execute_region_inline_multi_yield() {
16041604

16051605
// -----
16061606

1607+
// CHECK-LABEL: func.func private @canonicalize_execute_region_yeilding_external_value(
1608+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x120xui8>) -> tensor<1x60xui8> {
1609+
// CHECK: %[[VAL_1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1610+
// CHECK: scf.execute_region no_inline {
1611+
// CHECK: scf.yield
1612+
// CHECK: }
1613+
// CHECK: %[[VAL_2:.*]] = bufferization.to_tensor %[[VAL_1]] : memref<1x60xui8> to tensor<1x60xui8>
1614+
// CHECK: return %[[VAL_2]] : tensor<1x60xui8>
1615+
// CHECK: }
1616+
1617+
func.func private @canonicalize_execute_region_yeilding_external_value(%arg0: tensor<1x120xui8>) -> tensor<1x60xui8> {
1618+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1619+
%0 = bufferization.to_tensor %alloc : memref<1x60xui8> to tensor<1x60xui8>
1620+
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1621+
scf.yield %alloc : memref<1x60xui8>
1622+
}
1623+
%2 = bufferization.to_tensor %1 : memref<1x60xui8> to tensor<1x60xui8>
1624+
return %2 : tensor<1x60xui8>
1625+
}
1626+
1627+
// -----
1628+
16071629
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
16081630
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
16091631
func.func @canonicalize_parallel_insert_slice_indices(

0 commit comments

Comments
 (0)