Skip to content

Commit 7bf573a

Browse files
committed
Code review comments
1 parent 629949c commit 7bf573a

File tree

4 files changed

+153
-55
lines changed

4 files changed

+153
-55
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def ConditionOp : SCF_Op<"condition", [
7878

7979
def ExecuteRegionOp : SCF_Op<"execute_region", [
8080
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
81-
RecursiveMemoryEffects,
82-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {//, SingleBlockImplicitTerminator<"scf::YieldOp">]> { //, RecursiveMemoryEffects]> {
81+
RecursiveMemoryEffects]> {
8382
let summary = "operation that executes its region exactly once";
8483
let description = [{
8584
The `scf.execute_region` operation is used to allow multiple blocks within SCF

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2828
#include "mlir/Transforms/InliningUtils.h"
2929
#include "llvm/ADT/MapVector.h"
30+
#include "llvm/ADT/STLExtras.h"
3031
#include "llvm/ADT/SmallPtrSet.h"
3132
#include "llvm/Support/Casting.h"
3233
#include "llvm/Support/DebugLog.h"
@@ -306,20 +307,25 @@ struct ExecuteRegionForwardingEliminator
306307
SmallVector<Operation *> yieldOps;
307308
for (Block &block : op.getRegion()) {
308309
if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
309-
if (yield.getResults().empty())
310+
if (yield.getOperands().empty())
310311
continue;
311312
yieldOps.push_back(yield.getOperation());
312313
}
313314
}
314315

315-
if (yieldOps.size() != 1)
316+
if (yieldOps.empty())
316317
return failure();
317318

318-
auto yieldOp = cast<scf::YieldOp>(yieldOps.front());
319-
auto yieldedValues = yieldOp.getOperands();
319+
// Check if all yield operations have the same operands.
320+
auto yieldOpsOperands = yieldOps[0]->getOperands();
321+
for (auto *yieldOp : yieldOps) {
322+
if (yieldOp->getOperands() != yieldOpsOperands)
323+
return failure();
324+
}
325+
320326
// Check if all yielded values are from outside the region
321327
bool allExternal = true;
322-
for (Value yieldedValue : yieldedValues) {
328+
for (Value yieldedValue : yieldOpsOperands) {
323329
if (isValueFromInsideRegion(yieldedValue, op)) {
324330
allExternal = false;
325331
break;
@@ -337,13 +343,16 @@ struct ExecuteRegionForwardingEliminator
337343
// Move the region content to the new operation
338344
newOp.getRegion().takeBody(op.getRegion());
339345

340-
// Replace the yield operation with a new yield operation with no results.
341-
rewriter.setInsertionPoint(yieldOp);
342-
rewriter.eraseOp(yieldOp);
343-
rewriter.create<scf::YieldOp>(yieldOp.getLoc());
346+
// Replace all yield operations with a new yield operation with no results.
347+
// scf.execute_region must have at least one yield operation.
348+
for (auto *yieldOp : yieldOps) {
349+
rewriter.setInsertionPoint(yieldOp);
350+
rewriter.eraseOp(yieldOp);
351+
rewriter.create<scf::YieldOp>(yieldOp->getLoc());
352+
}
344353

345354
// Replace the old operation with the external values directly.
346-
rewriter.replaceOp(op, yieldedValues);
355+
rewriter.replaceOp(op, yieldOpsOperands);
347356
return success();
348357
}
349358

@@ -352,12 +361,11 @@ struct ExecuteRegionForwardingEliminator
352361
ExecuteRegionOp executeRegionOp) const {
353362
// Check if the value is defined within the execute_region
354363
if (Operation *defOp = value.getDefiningOp())
355-
return executeRegionOp.getRegion() = defOp->getParentRegion();
364+
return &executeRegionOp.getRegion() == defOp->getParentRegion();
356365

357366
// If it's a block argument, check if it's from within the region
358-
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) {
359-
return executeRegionOp.getRegion() == blockArg.getParentRegion());
360-
}
367+
if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
368+
return &executeRegionOp.getRegion() == blockArg.getParentRegion();
361369

362370
return false; // Value is from outside the region
363371
}
@@ -369,17 +377,6 @@ void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
369377
ExecuteRegionForwardingEliminator>(context);
370378
}
371379

372-
void ExecuteRegionOp::getEffects(
373-
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
374-
&effects) {
375-
if (!getNoInline())
376-
return;
377-
// In case there is attribute no_inline we want the region not to be inlined
378-
// into the parent operation.
379-
effects.emplace_back(MemoryEffects::Write::get(),
380-
SideEffects::DefaultResource::get());
381-
}
382-
383380
void ExecuteRegionOp::getSuccessorRegions(
384381
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
385382
// If the predecessor is the ExecuteRegionOp, branch into the body.

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -377,21 +377,6 @@ func.func private @execute_region_test(%t1 : tensor<?xf32>)
377377
// CHECK: return %{{.*}}, %{{.*}} : f32, f32
378378
return %0, %1, %2 : f32, tensor<?xf32>, f32
379379
}
380-
381-
// -----
382-
383-
// CHECK-LABEL: func @no_inline_execute_region_not_canonicalized
384-
func.func @no_inline_execute_region_not_canonicalized() {
385-
%c = arith.constant 42 : i32
386-
// CHECK: scf.execute_region
387-
// CHECK-SAME: no_inline
388-
%v = scf.execute_region -> i32 no_inline {
389-
scf.yield %c : i32
390-
}
391-
// CHECK: return
392-
return
393-
}
394-
395380
// -----
396381

397382
// CHECK: func private @some_external_func(memref<?xf32, strided<[?], offset: ?>>)

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,24 +1604,141 @@ func.func @func_execute_region_inline_multi_yield() {
16041604

16051605
// -----
16061606

1607-
// CHECK-LABEL: func.func private @canonicalize_execute_region_yeilding_external_value(
1608-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x120xui8>) -> tensor<1x60xui8> {
1609-
// CHECK: %[[VAL_1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1607+
// Test case with single scf.yield op inside execute_region and its operand is defined outside the execute_region op.
1608+
// Expect scf.execute_region to be rewritten so that it returns nothing.
1609+
16101610
// CHECK: scf.execute_region no_inline {
16111611
// CHECK: scf.yield
16121612
// CHECK: }
1613-
// CHECK: %[[VAL_2:.*]] = bufferization.to_tensor %[[VAL_1]] : memref<1x60xui8> to tensor<1x60xui8>
1614-
// CHECK: return %[[VAL_2]] : tensor<1x60xui8>
1615-
// CHECK: }
16161613

1617-
func.func private @canonicalize_execute_region_yeilding_external_value(%arg0: tensor<1x120xui8>) -> tensor<1x60xui8> {
1618-
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1619-
%0 = bufferization.to_tensor %alloc : memref<1x60xui8> to tensor<1x60xui8>
1620-
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1621-
scf.yield %alloc : memref<1x60xui8>
1614+
module {
1615+
func.func private @foo()->()
1616+
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
1617+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1618+
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
1619+
func.call @foo():()->()
1620+
scf.yield %alloc: memref<1x60xui8>
1621+
}
1622+
return %1 : memref<1x60xui8>
1623+
}
1624+
}
1625+
1626+
// -----
1627+
1628+
// Test case with scf.yield op inside execute_region with multiple operands.
1629+
// One of operands is defined outside the execute_region op.
1630+
// In this case scf.execute_region doesn't change.
1631+
1632+
// CHECK: %[[VAL_1:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1633+
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1634+
1635+
module {
1636+
func.func private @foo()->()
1637+
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
1638+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1639+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1640+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1641+
func.call @foo():()->()
1642+
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
1643+
}
1644+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
1645+
}
1646+
}
1647+
1648+
// -----
1649+
1650+
// Test case with multiple scf.yield ops inside execute_region that all have the same operands, and those operands are defined outside the execute_region op.
1651+
// Expect scf.execute_region to be rewritten so that it returns nothing.
1652+
// scf.yield must remain, because scf.execute_region can't be empty.
1653+
1654+
// CHECK: scf.execute_region no_inline {
1655+
// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1
1656+
// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2
1657+
// CHECK: ^bb1:
1658+
// CHECK: scf.yield
1659+
// CHECK: ^bb2:
1660+
// CHECK: scf.yield
1661+
// CHECK: }
1662+
1663+
module {
1664+
func.func private @foo()->()
1665+
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
1666+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1667+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1668+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1669+
%c = "test.cmp"() : () -> i1
1670+
cf.cond_br %c, ^bb2, ^bb3
1671+
^bb2:
1672+
func.call @foo():()->()
1673+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1674+
^bb3:
1675+
func.call @foo():()->()
1676+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1677+
}
1678+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
16221679
}
1623-
%2 = bufferization.to_tensor %1 : memref<1x60xui8> to tensor<1x60xui8>
1624-
return %2 : tensor<1x60xui8>
1680+
}
1681+
1682+
// -----
1683+
1684+
// Test case with multiple scf.yield ops where at least one operand differs between them; in this case scf.execute_region isn't changed.
1685+
1686+
// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1687+
// CHECK: ^bb1:
1688+
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1689+
// CHECK: ^bb2:
1690+
// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
1691+
// CHECK: }
1692+
1693+
module {
1694+
func.func private @foo()->()
1695+
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
1696+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1697+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1698+
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
1699+
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
1700+
%c = "test.cmp"() : () -> i1
1701+
cf.cond_br %c, ^bb2, ^bb3
1702+
^bb2:
1703+
func.call @foo():()->()
1704+
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
1705+
^bb3:
1706+
func.call @foo():()->()
1707+
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
1708+
}
1709+
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
1710+
}
1711+
}
1712+
1713+
// -----
1714+
1715+
// Test case with multiple scf.yield ops, each with a different operand.
1716+
// In this case scf.execute_region isn't changed.
1717+
1718+
// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline {
1719+
// CHECK: ^bb1:
1720+
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
1721+
// CHECK: ^bb2:
1722+
// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
1723+
// CHECK: }
1724+
1725+
module {
1726+
func.func private @foo()->()
1727+
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
1728+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1729+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
1730+
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
1731+
%c = "test.cmp"() : () -> i1
1732+
cf.cond_br %c, ^bb2, ^bb3
1733+
^bb2:
1734+
func.call @foo():()->()
1735+
scf.yield %alloc : memref<1x60xui8>
1736+
^bb3:
1737+
func.call @foo():()->()
1738+
scf.yield %alloc_1 : memref<1x60xui8>
1739+
}
1740+
return %1 : memref<1x60xui8>
1741+
}
16251742
}
16261743

16271744
// -----

0 commit comments

Comments
 (0)