From f0140c540491eaa0305d339f4ce41002b616c6ed Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Wed, 8 Oct 2025 11:20:09 -0700 Subject: [PATCH 1/4] [mlir] Execute same operand name constraints before user constraints. For a pattern like this: Pat<(MyOp $x, $x), (...), [(MyCheck $x)]>; The old implementation generates: Pat<(MyOp $x0, $x1), (...), [(MyCheck $x0), ($x0 == $x1)]>; This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching. This commit moves the equality checks before the other constraints, i.e.: Pat<(MyOp $x0, $x1), (...), [($x0 == $x1), (MyCheck $x0)]>; --- mlir/test/lib/Dialect/Test/TestOps.td | 13 ++++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 +++ mlir/test/mlir-tblgen/pattern.mlir | 28 +++++++++--- mlir/tools/mlir-tblgen/RewriterGen.cpp | 49 +++++++++++---------- 4 files changed, 67 insertions(+), 28 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6ea27187655ee..ed62bee3bc152 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> { let results = (outs I32); } +def OpQ : TEST_Op<"op_q"> { + let arguments = (ins AnyType, AnyType); + let results = (outs AnyType); +} + // Test constant-folding a pattern that maps `(F32) -> SI32`. def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> { let arguments = (ins RankedTensorOf<[F32]>:$operand); @@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern : def TestMultipleEqualArgsPattern : Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>; +// Test equal arguments checks are applied before user provided constraints. +// CheckIntIs32Bits would throw exceptions if input is not i32. +def CheckIntIs32Bits : Constraint>; +def TestEqualArgsCheckBeforeUserConstraintsPattern : + Pat<(OpQ $x, $x), + (replaceWithValue $x), + [(CheckIntIs32Bits $x)]>; + // Test for memrefs normalization of an op with normalizable memrefs. def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> { let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f8b5144e3acb2..d764deb023873 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -70,6 +70,11 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) { return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); } +// Requires input value is of i32 type. +static bool intIs32Bits(Value v) { + return mlir::dyn_cast(v.getType()).getWidth() == 32; +} + namespace { #include "TestPatterns.inc" } // namespace diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index bd55338618eec..a67830373e701 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs( // def TestNestedOpEqualArgsPattern : // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>; - // CHECK: %arg1 + // CHECK: "test.op_o"(%arg1) %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32) + %2 = "test.op_o"(%1) : (i32) -> (i32) - // CHECK: test.op_p - // CHECK: test.op_n - %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) + // CHECK-NEXT: %[[P:.*]] = "test.op_p" + // CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]]) + // CHECK-NEXT: "test.op_o"(%[[N]]) + %3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) - %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32) + %4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32) + %5 = "test.op_o"(%4) : (i32) -> (i32) return } @@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs( return } +func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) { + // def TestEqualArgsCheckBeforeUserConstraintsPattern : + // Pat<(OpQ $x, $x), + // [(CheckIntIs32Bits $x)], + // (replaceWithValue $x)>; + + // CHECK: "test.op_q"(%arg0, %arg1) + %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32) + + // CHECK: "test.op_q"(%arg1, %arg0) + %1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32) + + return +} + //===----------------------------------------------------------------------===// // Test Symbol Binding //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 605033daa719f..40bc1a9c3868c 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { int depth = 0; emitMatch(tree, opName, depth); + // Some of the operands could be bound to the same symbol name, we need + // to enforce equality constraint on those. + // This has to happen before user provided constraints, which may assume the + // same name checks are already performed, since in the pattern source code + // the user provided constraints appear later. + // TODO: we should be able to emit equality checks early + // and short circuit unnecessary work if vars are not equal. + for (auto symbolInfoIt = symbolInfoMap.begin(); + symbolInfoIt != symbolInfoMap.end();) { + auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); + auto startRange = range.first; + auto endRange = range.second; + + auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); + for (++startRange; startRange != endRange; ++startRange) { + auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); + emitMatchCheck( + opName, + formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), + formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, + secondOperand)); + } + + symbolInfoIt = endRange; + } + for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; auto &entities = appliedConstraint.entities; @@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { } } - // Some of the operands could be bound to the same symbol name, we need - // to enforce equality constraint on those. - // TODO: we should be able to emit equality checks early - // and short circuit unnecessary work if vars are not equal. - for (auto symbolInfoIt = symbolInfoMap.begin(); - symbolInfoIt != symbolInfoMap.end();) { - auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); - auto startRange = range.first; - auto endRange = range.second; - - auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); - for (++startRange; startRange != endRange; ++startRange) { - auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); - emitMatchCheck( - opName, - formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), - formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, - secondOperand)); - } - - symbolInfoIt = endRange; - } - LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); } From 5df4c81ab78d8a72e80499b8071cb0fb0d762161 Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Wed, 8 Oct 2025 14:21:35 -0700 Subject: [PATCH 2/4] use explicit check instead of dyn_cast --- mlir/test/lib/Dialect/Test/TestOps.td | 8 ++++---- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 11 ++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ed62bee3bc152..6329d61ba691b 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1213,12 +1213,12 @@ def TestMultipleEqualArgsPattern : Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>; // Test equal arguments checks are applied before user provided constraints. -// CheckIntIs32Bits would throw exceptions if input is not i32. -def CheckIntIs32Bits : Constraint>; +def AssertBinOpEqualArgsAndReturnTrue : Constraint< + CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>; def TestEqualArgsCheckBeforeUserConstraintsPattern : - Pat<(OpQ $x, $x), + Pat<(OpQ:$op $x, $x), (replaceWithValue $x), - [(CheckIntIs32Bits $x)]>; + [(AssertBinOpEqualArgsAndReturnTrue $op)]>; // Test for memrefs normalization of an op with normalizable memrefs. def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index d764deb023873..e7be1981d364c 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -70,9 +70,14 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) { return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); } -// Requires input value is of i32 type. -static bool intIs32Bits(Value v) { - return mlir::dyn_cast(v.getType()).getWidth() == 32; +static bool assertBinOpEqualArgsAndReturnTrue(Value v) { + Operation* operation = v.getDefiningOp(); + if (operation->getOperand(0) != operation->getOperand(1)) { + // Name binding equality check must happen before user-defined constraints, + // thus this must not be triggered. + llvm::report_fatal_error("Arguments are not equal"); + } + return true; } namespace { From 5a31d36916ad68df019d917fb69c4fa3076f302b Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Wed, 8 Oct 2025 14:23:49 -0700 Subject: [PATCH 3/4] update pattern comment in .mlir test --- mlir/test/mlir-tblgen/pattern.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index a67830373e701..ffb78c28412ce 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -211,9 +211,9 @@ func.func @verifyMultipleEqualArgs( func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) { // def TestEqualArgsCheckBeforeUserConstraintsPattern : - // Pat<(OpQ $x, $x), - // [(CheckIntIs32Bits $x)], - // (replaceWithValue $x)>; + // Pat<(OpQ:$op $x, $x), + // (replaceWithValue $x), + // [(AssertBinOpEqualArgsAndReturnTrue $op)]>; // CHECK: "test.op_q"(%arg0, %arg1) %0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32) From cd8a8837b0448fc38477feb702fbc573dfad4c55 Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Wed, 8 Oct 2025 14:56:17 -0700 Subject: [PATCH 4/4] fix format --- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index e7be1981d364c..ee4fa39158721 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -71,7 +71,7 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) { } static bool assertBinOpEqualArgsAndReturnTrue(Value v) { - Operation* operation = v.getDefiningOp(); + Operation *operation = v.getDefiningOp(); if (operation->getOperand(0) != operation->getOperand(1)) { // Name binding equality check must happen before user-defined constraints, // thus this must not be triggered.