Skip to content

Commit 2dd3d38

Browse files
authored
[MLIR] getBackwardSlice: don't bail on ops that are IsolatedFromAbove (#158135)
Ops with the `IsIsolatedFromAbove` trait should be captured by the backward slice. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent dfad983 commit 2dd3d38

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
109109
DenseSet<Operation *> &visited,
110110
SetVector<Operation *> *backwardSlice,
111111
const BackwardSliceOptions &options) {
112-
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
112+
if (!op)
113113
return success();
114114

115115
// Evaluate whether we should keep this def.
@@ -136,7 +136,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
136136
// blocks of parentOp, which are not technically backward unless they flow
137137
// into us. For now, just bail.
138138
if (parentOp && backwardSlice->count(parentOp) == 0) {
139-
if (parentOp->getNumRegions() == 1 &&
139+
if (!parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
140+
parentOp->getNumRegions() == 1 &&
140141
parentOp->getRegion(0).hasOneBlock()) {
141142
return getBackwardSliceImpl(parentOp, visited, backwardSlice,
142143
options);
@@ -150,7 +151,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
150151

151152
bool succeeded = true;
152153

153-
if (!options.omitUsesFromAbove) {
154+
if (!options.omitUsesFromAbove &&
155+
!op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
154156
llvm::for_each(op->getRegions(), [&](Region &region) {
155157
// Walk this region recursively to collect the regions that descend from
156158
// this op's nested regions (inclusive).

mlir/test/Transforms/move-operation-deps.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,3 +460,31 @@ module attributes {transform.with_named_sequence} {
460460
transform.yield
461461
}
462462
}
463+
464+
// -----
465+
466+
func.func @move_isolated_from_above() -> () {
467+
%1 = "before"() : () -> (f32)
468+
%2 = "moved0"() : () -> (f32)
469+
%3 = test.isolated_one_region_op %2 {} : f32 -> f32
470+
%4 = "moved1"(%3) : (f32) -> (f32)
471+
return
472+
}
473+
// CHECK-LABEL: func @move_isolated_from_above()
474+
// CHECK: %[[MOVED0:.+]] = "moved0"
475+
// CHECK: %[[ISOLATED:.+]] = test.isolated_one_region_op %[[MOVED0]]
476+
// CHECK: %[[MOVED1:.+]] = "moved1"(%[[ISOLATED]])
477+
// CHECK: %[[BEFORE:.+]] = "before"
478+
479+
module attributes {transform.with_named_sequence} {
480+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
481+
%op1 = transform.structured.match ops{["before"]} in %arg0
482+
: (!transform.any_op) -> !transform.any_op
483+
%op2 = transform.structured.match ops{["moved1"]} in %arg0
484+
: (!transform.any_op) -> !transform.any_op
485+
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
486+
transform.test.move_value_defns %v1 before %op1
487+
: (!transform.any_value), !transform.any_op
488+
transform.yield
489+
}
490+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,10 @@ def OneRegionWithOperandsOp : TEST_Op<"one_region_with_operands_op", []> {
552552

553553
def IsolatedOneRegionOp : TEST_Op<"isolated_one_region_op", [IsolatedFromAbove]> {
554554
let arguments = (ins Variadic<AnyType>:$operands);
555+
let results = (outs Variadic<AnyType>:$results);
555556
let regions = (region AnyRegion:$my_region);
556557
let assemblyFormat = [{
557-
attr-dict-with-keyword $operands $my_region `:` type($operands)
558+
attr-dict-with-keyword $operands $my_region `:` type($operands) `->` type($results)
558559
}];
559560
}
560561

mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ makeIsolatedFromAboveImpl(RewriterBase &rewriter,
2727
makeRegionIsolatedFromAbove(rewriter, region, callBack);
2828
SmallVector<Value> operands = regionOp.getOperands();
2929
operands.append(capturedValues);
30-
auto isolatedRegionOp =
31-
test::IsolatedOneRegionOp::create(rewriter, regionOp.getLoc(), operands);
30+
auto isolatedRegionOp = test::IsolatedOneRegionOp::create(
31+
rewriter, regionOp.getLoc(), TypeRange(), operands);
3232
rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
3333
isolatedRegionOp.getRegion().begin());
3434
rewriter.eraseOp(regionOp);

0 commit comments

Comments
 (0)