Skip to content

Commit 1a32404

Browse files
committed
Adding to execute_region_op some missing support
1 parent 4ad625b commit 1a32404

File tree

3 files changed

+107
-3
lines changed

3 files changed

+107
-3
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ def ConditionOp : SCF_Op<"condition", [
7777
//===----------------------------------------------------------------------===//
7878

7979
def ExecuteRegionOp : SCF_Op<"execute_region", [
80-
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
80+
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
81+
RecursiveMemoryEffects,
82+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {//, SingleBlockImplicitTerminator<"scf::YieldOp">]> { //, RecursiveMemoryEffects]> {
8183
let summary = "operation that executes its region exactly once";
8284
let description = [{
8385
The `scf.execute_region` operation is used to allow multiple blocks within SCF

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,89 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291291
}
292292
};
293293

294+
// Pattern to eliminate ExecuteRegionOp results when it only forwards external values.
295+
// It operates only on execute regions with single terminator yield operation.
296+
struct ExecuteRegionForwardingEliminator : public OpRewritePattern<ExecuteRegionOp> {
297+
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
298+
299+
LogicalResult matchAndRewrite(ExecuteRegionOp op,
300+
PatternRewriter &rewriter) const override {
301+
if (op.getNumResults() == 0)
302+
return failure();
303+
304+
SmallVector<Operation*> yieldOps;
305+
for (Block &block : op.getRegion()) {
306+
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
307+
if (yield.getResults().empty())
308+
continue;
309+
yieldOps.push_back(yield.getOperation());
310+
}
311+
}
312+
313+
if(yieldOps.size() != 1)
314+
return failure();
315+
316+
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOps.front());
317+
auto yieldedValues = yieldOp.getOperands();
318+
// Check if all yielded values are from outside the region
319+
bool allExternal = true;
320+
for (Value yieldedValue : yieldedValues) {
321+
if (isValueFromInsideRegion(yieldedValue, op)) {
322+
allExternal = false;
323+
break;
324+
}
325+
}
326+
327+
if (!allExternal)
328+
return failure();
329+
330+
// All yielded values are external - create a new execute_region with no results.
331+
auto newOp = rewriter.create<ExecuteRegionOp>(op.getLoc(), TypeRange{});
332+
newOp->setAttrs(op->getAttrs());
333+
334+
// Move the region content to the new operation
335+
newOp.getRegion().takeBody(op.getRegion());
336+
337+
// Replace the yield operation with a new yield operation with no results.
338+
rewriter.setInsertionPoint(yieldOp);
339+
rewriter.eraseOp(yieldOp);
340+
rewriter.create<scf::YieldOp>(yieldOp.getLoc());
341+
342+
// Replace the old operation with the external values directly.
343+
rewriter.replaceOp(op, yieldedValues);
344+
return success();
345+
}
346+
347+
private:
348+
bool isValueFromInsideRegion(Value value, ExecuteRegionOp executeRegionOp) const {
349+
// Check if the value is defined within the execute_region
350+
if (Operation *defOp = value.getDefiningOp()) {
351+
return executeRegionOp.getRegion().isAncestor(defOp->getParentRegion());
352+
}
353+
354+
// If it's a block argument, check if it's from within the region
355+
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
356+
return executeRegionOp.getRegion().isAncestor(blockArg.getParentRegion());
357+
}
358+
359+
return false; // Value is from outside the region
360+
}
361+
};
362+
294363
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
295-
MLIRContext *context) {
296-
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
364+
MLIRContext *context) {
365+
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
366+
ExecuteRegionForwardingEliminator>(context);
367+
}
368+
369+
void ExecuteRegionOp::getEffects(
370+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
371+
&effects) {
372+
if (!getNoInline())
373+
return;
374+
// In case there is attribute no_inline we want the region not to be inlined into the parent operation.
375+
effects.emplace_back(MemoryEffects::Write::get(),
376+
SideEffects::DefaultResource::get());
297377
}
298378

299379
void ExecuteRegionOp::getSuccessorRegions(

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,28 @@ func.func @func_execute_region_inline_multi_yield() {
16041604

16051605
// -----
16061606

1607+
// CHECK-LABEL: func.func private @canonicalize_execute_region_yeilding_external_value(
1608+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x120xui8>) -> tensor<1x60xui8> {
1609+
// CHECK: %[[VAL_1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1610+
// CHECK: scf.execute_region no_inline {
1611+
// CHECK: scf.yield
1612+
// CHECK: }
1613+
// CHECK: %[[VAL_2:.*]] = bufferization.to_tensor %[[VAL_1]] : memref<1x60xui8> to tensor<1x60xui8>
1614+
// CHECK: return %[[VAL_2]] : tensor<1x60xui8>
1615+
// CHECK: }
1616+
1617+
func.func private @canonicalize_execute_region_yeilding_external_value(%arg0: tensor<1x120xui8>) -> tensor<1x60xui8> {
1618+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1619+
%0 = bufferization.to_tensor %alloc : memref<1x60xui8> to tensor<1x60xui8>
1620+
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1621+
scf.yield %alloc : memref<1x60xui8>
1622+
}
1623+
%2 = bufferization.to_tensor %1 : memref<1x60xui8> to tensor<1x60xui8>
1624+
return %2 : tensor<1x60xui8>
1625+
}
1626+
1627+
// -----
1628+
16071629
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
16081630
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
16091631
func.func @canonicalize_parallel_insert_slice_indices(

0 commit comments

Comments
 (0)