diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6ea27187655ee..6329d61ba691b 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> { let results = (outs I32); } +def OpQ : TEST_Op<"op_q"> { + let arguments = (ins AnyType, AnyType); + let results = (outs AnyType); +} + // Test constant-folding a pattern that maps `(F32) -> SI32`. def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> { let arguments = (ins RankedTensorOf<[F32]>:$operand); @@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern : def TestMultipleEqualArgsPattern : Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>; +// Test equal arguments checks are applied before user provided constraints. +def AssertBinOpEqualArgsAndReturnTrue : Constraint< + CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>; +def TestEqualArgsCheckBeforeUserConstraintsPattern : + Pat<(OpQ:$op $x, $x), + (replaceWithValue $x), + [(AssertBinOpEqualArgsAndReturnTrue $op)]>; + // Test for memrefs normalization of an op with normalizable memrefs. def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> { let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f8b5144e3acb2..ee4fa39158721 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -70,6 +70,16 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) { return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); } +static bool assertBinOpEqualArgsAndReturnTrue(Value v) { + Operation *operation = v.getDefiningOp(); + if (operation->getOperand(0) != operation->getOperand(1)) { + // Name binding equality check must happen before user-defined constraints, + // thus this must not be triggered. + llvm::report_fatal_error("Arguments are not equal"); + } + return true; +} + namespace { #include "TestPatterns.inc" } // namespace diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index bd55338618eec..ffb78c28412ce 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs( // def TestNestedOpEqualArgsPattern : // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>; - // CHECK: %arg1 + // CHECK: "test.op_o"(%arg1) %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32) + %2 = "test.op_o"(%1) : (i32) -> (i32) - // CHECK: test.op_p - // CHECK: test.op_n - %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) + // CHECK-NEXT: %[[P:.*]] = "test.op_p" + // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]]) + // CHECK-NEXT: "test.op_o"(%[[N]]) + %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) - %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32) + %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32) + %5 = "test.op_o"(%4) : (i32) -> (i32) return } @@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs( return } +func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) { + // def TestEqualArgsCheckBeforeUserConstraintsPattern : + // Pat<(OpQ:$op $x, $x), + // (replaceWithValue $x), + // [(AssertBinOpEqualArgsAndReturnTrue $op)]>; + + // CHECK: "test.op_q"(%arg0, %arg1) + %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32) + + // CHECK: "test.op_q"(%arg1, %arg0) + %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32) + + return +} + //===----------------------------------------------------------------------===// // Test Symbol Binding //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 605033daa719f..40bc1a9c3868c 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { int depth = 0; emitMatch(tree, opName, depth); + // Some of the operands could be bound to the same symbol name, we need + // to enforce equality constraint on those. + // This has to happen before user provided constraints, which may assume the + // same name checks are already performed, since in the pattern source code + // the user provided constraints appear later. + // TODO: we should be able to emit equality checks early + // and short circuit unnecessary work if vars are not equal. + for (auto symbolInfoIt = symbolInfoMap.begin(); + symbolInfoIt != symbolInfoMap.end();) { + auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); + auto startRange = range.first; + auto endRange = range.second; + + auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); + for (++startRange; startRange != endRange; ++startRange) { + auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); + emitMatchCheck( + opName, + formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), + formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, + secondOperand)); + } + + symbolInfoIt = endRange; + } + for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; auto &entities = appliedConstraint.entities; @@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { } } - // Some of the operands could be bound to the same symbol name, we need - // to enforce equality constraint on those. - // TODO: we should be able to emit equality checks early - // and short circuit unnecessary work if vars are not equal. - for (auto symbolInfoIt = symbolInfoMap.begin(); - symbolInfoIt != symbolInfoMap.end();) { - auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); - auto startRange = range.first; - auto endRange = range.second; - - auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); - for (++startRange; startRange != endRange; ++startRange) { - auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); - emitMatchCheck( - opName, - formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), - formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, - secondOperand)); - } - - symbolInfoIt = endRange; - } - LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); }