Skip to content

Commit 1697e33

Browse files
Modify TransformRewriter listener to get the match failure remark and use it to test failure in the op.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent d13fc99 commit 1697e33

File tree

6 files changed

+66
-6
lines changed

6 files changed

+66
-6
lines changed

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,10 +1074,18 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10741074
/// resets the error state to "success".
10751075
DiagnosedSilenceableFailure checkAndResetError();
10761076

1077+
/// Return the latest match notification message.
1078+
std::string getLatestMatchFailureMessage();
1079+
10771080
/// Return "true" if this tracking listener had a failure.
10781081
bool failed() const;
10791082

10801083
protected:
1084+
1085+
void
1086+
notifyMatchFailure(Location loc,
1087+
function_ref<void(Diagnostic &)> reasonCallback) override;
1088+
10811089
void
10821090
notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
10831091
DiagnosedSilenceableFailure &&diag) override;
@@ -1089,6 +1097,9 @@ class ErrorCheckingTrackingListener : public TrackingListener {
10891097

10901098
/// The number of errors that have been encountered.
10911099
int64_t errorCounter = 0;
1100+
1101+
/// Latest message from match failure notification.
1102+
std::string matchFailureMsg = "";
10921103
};
10931104

10941105
/// This is a special rewriter to be used in transform op implementations,

mlir/include/mlir/Transforms/RegionUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
7272

7373
/// Move SSA values used within an operation before an insertion point,
7474
/// so that the operation itself (or its replacement) can be moved to
75-
/// the insertion point.
75+
/// the insertion point. Current support is only for movement of
76+
/// dependencies of `op` before `insertionPoint` in the same basic block.
7677
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
7778
Operation *insertionPoint,
7879
DominanceInfo &dominance);

mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,17 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
13901390
++errorCounter;
13911391
}
13921392

1393+
std::string transform::ErrorCheckingTrackingListener::getLatestMatchFailureMessage() {
1394+
return matchFailureMsg;
1395+
}
1396+
1397+
void transform::ErrorCheckingTrackingListener::notifyMatchFailure(
1398+
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1399+
Diagnostic diag(loc, DiagnosticSeverity::Remark);
1400+
reasonCallback(diag);
1401+
matchFailureMsg = diag.str();
1402+
}
1403+
13931404
//===----------------------------------------------------------------------===//
13941405
// TransformRewriter
13951406
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,13 +1073,21 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
10731073
op, "unsupported caes where operation and insertion point are not in "
10741074
"the same basic block");
10751075
}
1076+
// If `insertionPoint` does not dominate `op`, do nothing
1077+
if (!dominance.properlyDominates(insertionPoint, op)) {
1078+
return rewriter.notifyMatchFailure(op,
1079+
"insertion point does not dominate op");
1080+
}
10761081

10771082
// Find the backward slice of operation for each `Value` the operation
10781083
// depends on. Prune the slice to only include operations not already
10791084
// dominated by the `insertionPoint`
10801085
BackwardSliceOptions options;
10811086
options.inclusive = true;
10821087
options.omitUsesFromAbove = false;
1088+
// Since current support is to only move within a same basic block,
1089+
// the slices dont need to look past block arguments.
1090+
options.omitBlockArguments = true;
10831091
options.filter = [&](Operation *sliceBoundaryOp) {
10841092
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
10851093
};

mlir/test/Transforms/move-operation-deps.mlir

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt --allow-unregistered-dialect --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
22

33
// Check simple move of dependent operation before insertion.
44
func.func @simple_move() -> f32 {
@@ -95,7 +95,7 @@ module attributes {transform.with_named_sequence} {
9595
func.func @move_region_dependencies() -> f32 {
9696
%0 = "before"() : () -> (f32)
9797
%1 = "moved_op_1"() : () -> (f32)
98-
%2 = "moved_op"() ({
98+
%2 = "moved_op_2"() ({
9999
"yield"(%1) : (f32) -> ()
100100
}) : () -> (f32)
101101
%3 = "foo"() ({
@@ -139,6 +139,7 @@ module attributes {transform.with_named_sequence} {
139139
: (!transform.any_op) -> !transform.any_op
140140
%op2 = transform.structured.match ops{["before"]} in %arg0
141141
: (!transform.any_op) -> !transform.any_op
142+
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
142143
transform.test.move_operand_deps %op1 before %op2
143144
: !transform.any_op, !transform.any_op
144145
transform.yield
@@ -147,7 +148,9 @@ module attributes {transform.with_named_sequence} {
147148

148149
// -----
149150

150-
func.func @move_region_dependencies() -> f32 {
151+
// Fail when the "before" operation is part of the operation slice (computed
152+
// when looking through implicitly captured values).
153+
func.func @do_not_move_slice() -> f32 {
151154
%0 = "before"() : () -> (f32)
152155
%1 = "moved_op"() ({
153156
"yield"(%0) : (f32) -> ()
@@ -164,6 +167,32 @@ module attributes {transform.with_named_sequence} {
164167
: (!transform.any_op) -> !transform.any_op
165168
%op2 = transform.structured.match ops{["before"]} in %arg0
166169
: (!transform.any_op) -> !transform.any_op
170+
// expected-remark@+1{{cannot move dependencies before operation in backward slice of op}}
171+
transform.test.move_operand_deps %op1 before %op2
172+
: !transform.any_op, !transform.any_op
173+
transform.yield
174+
}
175+
}
176+
177+
// -----
178+
179+
// Dont move ops when insertion point does not dominate the op
180+
func.func @do_not_move() -> f32 {
181+
%1 = "moved_op"() : () -> (f32)
182+
%2 = "foo"() ({
183+
"yield"(%1) : (f32) -> ()
184+
}) : () -> (f32)
185+
%3 = "before"() : () -> f32
186+
return %2 : f32
187+
}
188+
189+
module attributes {transform.with_named_sequence} {
190+
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
191+
%op1 = transform.structured.match ops{["foo"]} in %arg0
192+
: (!transform.any_op) -> !transform.any_op
193+
%op2 = transform.structured.match ops{["before"]} in %arg0
194+
: (!transform.any_op) -> !transform.any_op
195+
// expected-remark@+1{{insertion point does not dominate op}}
167196
transform.test.move_operand_deps %op1 before %op2
168197
: !transform.any_op, !transform.any_op
169198
transform.yield

mlir/test/lib/Transforms/TestTransformsOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
3333
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
3434
if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
3535
auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
36-
std::string errorMsg = listener->checkAndResetError().getMessage();
37-
return emitSilenceableFailure(op, errorMsg);
36+
std::string errorMsg = listener->getLatestMatchFailureMessage();
37+
(void)emitRemark(errorMsg);
3838
}
3939
return DiagnosedSilenceableFailure::success();
4040
}

0 commit comments

Comments
 (0)