Skip to content

Commit a35f1fc

Browse files
ddubov100joker-eph
authored andcommitted
Adding to execute_region_op some missing support (llvm#164159)
Adding canonicalization pattern in case execute_region op has yieldOps which operands are from outside the execute_region, then it simplifies the op to return just internal values. The pattern is applied only in case all yieldOps within execute_region_op have same operands --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 883a01b commit a35f1fc

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2828
#include "mlir/Transforms/InliningUtils.h"
2929
#include "llvm/ADT/MapVector.h"
30+
#include "llvm/ADT/STLExtras.h"
3031
#include "llvm/ADT/SmallPtrSet.h"
3132
#include "llvm/Support/Casting.h"
3233
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
291292
}
292293
};
293294

295+
// Pattern to eliminate ExecuteRegionOp results which forward external
296+
// values from the region. In case there are multiple yield operations,
297+
// all of them must have the same operands in order for the pattern to be
298+
// applicable.
299+
struct ExecuteRegionForwardingEliminator
300+
: public OpRewritePattern<ExecuteRegionOp> {
301+
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
302+
303+
LogicalResult matchAndRewrite(ExecuteRegionOp op,
304+
PatternRewriter &rewriter) const override {
305+
if (op.getNumResults() == 0)
306+
return failure();
307+
308+
SmallVector<Operation *> yieldOps;
309+
for (Block &block : op.getRegion()) {
310+
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
311+
yieldOps.push_back(yield.getOperation());
312+
}
313+
314+
if (yieldOps.empty())
315+
return failure();
316+
317+
// Check if all yield operations have the same operands.
318+
auto yieldOpsOperands = yieldOps[0]->getOperands();
319+
for (auto *yieldOp : yieldOps) {
320+
if (yieldOp->getOperands() != yieldOpsOperands)
321+
return failure();
322+
}
323+
324+
SmallVector<Value> externalValues;
325+
SmallVector<Value> internalValues;
326+
SmallVector<Value> opResultsToReplaceWithExternalValues;
327+
SmallVector<Value> opResultsToKeep;
328+
for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
329+
if (isValueFromInsideRegion(yieldedValue, op)) {
330+
internalValues.push_back(yieldedValue);
331+
opResultsToKeep.push_back(op.getResult(index));
332+
} else {
333+
externalValues.push_back(yieldedValue);
334+
opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
335+
}
336+
}
337+
// No yielded external values - nothing to do.
338+
if (externalValues.empty())
339+
return failure();
340+
341+
// There are yielded external values - create a new execute_region returning
342+
// just the internal values.
343+
SmallVector<Type> resultTypes;
344+
for (Value value : internalValues)
345+
resultTypes.push_back(value.getType());
346+
auto newOp =
347+
ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
348+
newOp->setAttrs(op->getAttrs());
349+
350+
// Move old op's region to the new operation.
351+
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
352+
newOp.getRegion().end());
353+
354+
// Replace all yield operations with a new yield operation with updated
355+
// results. scf.execute_region must have at least one yield operation.
356+
for (auto *yieldOp : yieldOps) {
357+
rewriter.setInsertionPoint(yieldOp);
358+
rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
359+
ValueRange(internalValues));
360+
}
361+
362+
// Replace the old operation with the external values directly.
363+
rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
364+
externalValues);
365+
// Replace the old operation's remaining results with the new operation's
366+
// results.
367+
rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
368+
rewriter.eraseOp(op);
369+
return success();
370+
}
371+
372+
private:
373+
bool isValueFromInsideRegion(Value value,
374+
ExecuteRegionOp executeRegionOp) const {
375+
// Check if the value is defined within the execute_region
376+
if (Operation *defOp = value.getDefiningOp())
377+
return &executeRegionOp.getRegion() == defOp->getParentRegion();
378+
379+
// If it's a block argument, check if it's from within the region
380+
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
381+
return &executeRegionOp.getRegion() == blockArg.getParentRegion();
382+
383+
return false; // Value is from outside the region
384+
}
385+
};
386+
294387
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
295388
MLIRContext *context) {
296-
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
389+
results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
390+
ExecuteRegionForwardingEliminator>(context);
297391
}
298392

299393
void ExecuteRegionOp::getSuccessorRegions(

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,148 @@ func.func @func_execute_region_inline_multi_yield() {
16041604

16051605
// -----
16061606

1607+
// Test case with single scf.yield op inside execute_region and its operand is defined outside the execute_region op.
1608+
// Make scf.execute_region not to return anything.
1609+
1610+
// CHECK: scf.execute_region no_inline {
1611+
// CHECK: func.call @foo() : () -> ()
1612+
// CHECK: scf.yield
1613+
// CHECK: }
1614+
1615+
module {
1616+
func.func private @foo()->()
1617+
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
1618+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1619+
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1620+
func.call @foo():()->()
1621+
scf.yield %alloc: memref<1x60xui8>
1622+
}
1623+
return %1 : memref<1x60xui8>
1624+
}
1625+
}
1626+
1627+
// -----
1628+
1629+
// Test case with scf.yield op inside execute_region with multiple operands.
1630+
// One of operands is defined outside the execute_region op.
1631+
// Remove just this operand from the op results.
1632+
1633+
// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
1634+
// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1635+
// CHECK: func.call @foo() : () -> ()
1636+
// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8>
1637+
// CHECK: }
1638+
module {
1639+
func.func private @foo()->()
1640+
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
1641+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1642+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1643+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1644+
func.call @foo():()->()
1645+
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
1646+
}
1647+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
1648+
}
1649+
}
1650+
1651+
// -----
1652+
1653+
// Test case with multiple scf.yield ops inside execute_region with same operands and those operands are defined outside the execute_region op..
1654+
// Make scf.execute_region not to return anything.
1655+
// scf.yield must remain, cause scf.execute_region can't be empty.
1656+
1657+
// CHECK: scf.execute_region no_inline {
1658+
// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1
1659+
// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2
1660+
// CHECK: ^bb1:
1661+
// CHECK: scf.yield
1662+
// CHECK: ^bb2:
1663+
// CHECK: scf.yield
1664+
// CHECK: }
1665+
1666+
module {
1667+
func.func private @foo()->()
1668+
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
1669+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1670+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1671+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1672+
%c = "test.cmp"() : () -> i1
1673+
cf.cond_br %c, ^bb2, ^bb3
1674+
^bb2:
1675+
func.call @foo():()->()
1676+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1677+
^bb3:
1678+
func.call @foo():()->()
1679+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1680+
}
1681+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
1682+
}
1683+
}
1684+
1685+
// -----
1686+
1687+
// Test case with multiple scf.yield ops with at least one different operand, then no change.
1688+
1689+
// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1690+
// CHECK: ^bb1:
1691+
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1692+
// CHECK: ^bb2:
1693+
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1694+
// CHECK: }
1695+
1696+
module {
1697+
func.func private @foo()->()
1698+
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
1699+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1700+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1701+
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1702+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1703+
%c = "test.cmp"() : () -> i1
1704+
cf.cond_br %c, ^bb2, ^bb3
1705+
^bb2:
1706+
func.call @foo():()->()
1707+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1708+
^bb3:
1709+
func.call @foo():()->()
1710+
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
1711+
}
1712+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
1713+
}
1714+
}
1715+
1716+
// -----
1717+
1718+
// Test case with multiple scf.yield ops each has different operand.
1719+
// In this case scf.execute_region isn't changed.
1720+
1721+
// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline {
1722+
// CHECK: ^bb1:
1723+
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
1724+
// CHECK: ^bb2:
1725+
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
1726+
// CHECK: }
1727+
1728+
module {
1729+
func.func private @foo()->()
1730+
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
1731+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1732+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1733+
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
1734+
%c = "test.cmp"() : () -> i1
1735+
cf.cond_br %c, ^bb2, ^bb3
1736+
^bb2:
1737+
func.call @foo():()->()
1738+
scf.yield %alloc : memref<1x60xui8>
1739+
^bb3:
1740+
func.call @foo():()->()
1741+
scf.yield %alloc_1 : memref<1x60xui8>
1742+
}
1743+
return %1 : memref<1x60xui8>
1744+
}
1745+
}
1746+
1747+
// -----
1748+
16071749
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
16081750
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
16091751
func.func @canonicalize_parallel_insert_slice_indices(

0 commit comments

Comments
 (0)