Skip to content

Commit 81a0db6

Browse files
committed
Skip over all results in the Bytecode if a Constraint/Rewrite failed, instead of just skipping over the first result.
Skipping only over the first results leads to the curCodeIt pointing to the wrong location in the bytecode, causing the execution to continue with a wrong instruction after the Constraint/Rewrite. Signed-off-by: Rickert, Jonas <[email protected]>
1 parent b291cfc commit 81a0db6

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

mlir/lib/Rewrite/ByteCode.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,22 +1496,24 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
14961496
void ByteCodeExecutor::processNativeFunResults(
14971497
ByteCodeRewriteResultList &results, unsigned numResults,
14981498
LogicalResult &rewriteResult) {
1499-
// Store the results in the bytecode memory or handle missing results on
1500-
// failure.
1501-
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1502-
PDLValue::Kind resultKind = read<PDLValue::Kind>();
1503-
1499+
if (failed(rewriteResult)) {
15041500
// Skip the according number of values on the buffer on failure and exit
15051501
// early as there are no results to process.
1506-
if (failed(rewriteResult)) {
1502+
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1503+
const PDLValue::Kind resultKind = read<PDLValue::Kind>();
15071504
if (resultKind == PDLValue::Kind::TypeRange ||
15081505
resultKind == PDLValue::Kind::ValueRange) {
15091506
skip(2);
15101507
} else {
15111508
skip(1);
15121509
}
1513-
return;
15141510
}
1511+
return;
1512+
}
1513+
1514+
// Store the results in the bytecode memory
1515+
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1516+
PDLValue::Kind resultKind = read<PDLValue::Kind>();
15151517
PDLValue result = results.getResults()[resultIdx];
15161518
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
15171519
assert(result.getKind() == resultKind &&

mlir/test/Rewrite/pdl-bytecode.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,36 @@ module @ir attributes { test.apply_constraint_4 } {
143143

144144
// -----
145145

146+
// Test returning a type from a native constraint.
147+
module @patterns {
148+
pdl_interp.func @matcher(%root : !pdl.operation) {
149+
%new_type:2 = pdl_interp.apply_constraint "op_multiple_returns_failure"(%root : !pdl.operation) : !pdl.type, !pdl.type -> ^pat2, ^end
150+
151+
^pat2:
152+
pdl_interp.record_match @rewriters::@success(%root, %new_type#0 : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
153+
154+
^end:
155+
pdl_interp.finalize
156+
}
157+
158+
module @rewriters {
159+
pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
160+
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
161+
pdl_interp.erase %root
162+
pdl_interp.finalize
163+
}
164+
}
165+
}
166+
167+
// CHECK-LABEL: test.apply_constraint_multi_result_failure
168+
// CHECK-NOT: "test.replaced_by_pattern"
169+
// CHECK: "test.success_op"
170+
module @ir attributes { test.apply_constraint_multi_result_failure } {
171+
"test.success_op"() : () -> ()
172+
}
173+
174+
// -----
175+
146176
// Test success and failure cases of native constraints with pdl.range results.
147177
module @patterns {
148178
pdl_interp.func @matcher(%root : !pdl.operation) {

mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
5555
return failure();
5656
}
5757

58+
// Custom constraint that always returns failure
59+
static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
60+
PDLResultList & /*results*/,
61+
ArrayRef<PDLValue> /*args*/) {
62+
return failure();
63+
}
64+
5865
// Custom constraint that returns a type range of variable length if the op is
5966
// named test.success_op
6067
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
@@ -150,6 +157,8 @@ struct TestPDLByteCodePass
150157
customValueResultConstraint);
151158
pdlPattern.registerConstraintFunction("op_constr_return_type",
152159
customTypeResultConstraint);
160+
pdlPattern.registerConstraintFunction("op_multiple_returns_failure",
161+
customConstraintFailure);
153162
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
154163
customTypeRangeResultConstraint);
155164
pdlPattern.registerRewriteFunction("creator", customCreate);

0 commit comments

Comments
 (0)