Skip to content

Commit a75565a

Browse files
authored
[mlir] Execute same operand name constraints before user constraints. (#162526)
For a pattern like this: Pat<(MyOp $x, $x), (...), [(MyCheck $x)]>; The old implementation generates: Pat<(MyOp $x0, $x1), (...), [(MyCheck $x0), ($x0 == $x1)]>; This is not very straightforward, because the $x name appears in the source pattern; it's attempting to assume equality check will be performed as part of the source pattern matching. This commit moves the equality checks before the other constraints, i.e.: Pat<(MyOp $x0, $x1), (...), [($x0 == $x1), (MyCheck $x0)]>;
1 parent 9e0d3bc commit a75565a

File tree

4 files changed

+72
-28
lines changed

4 files changed

+72
-28
lines changed

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
11691169
let results = (outs I32);
11701170
}
11711171

1172+
def OpQ : TEST_Op<"op_q"> {
1173+
let arguments = (ins AnyType, AnyType);
1174+
let results = (outs AnyType);
1175+
}
1176+
11721177
// Test constant-folding a pattern that maps `(F32) -> SI32`.
11731178
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
11741179
let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
12071212
def TestMultipleEqualArgsPattern :
12081213
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
12091214

1215+
// Test equal arguments checks are applied before user provided constraints.
1216+
def AssertBinOpEqualArgsAndReturnTrue : Constraint<
1217+
CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>;
1218+
def TestEqualArgsCheckBeforeUserConstraintsPattern :
1219+
Pat<(OpQ:$op $x, $x),
1220+
(replaceWithValue $x),
1221+
[(AssertBinOpEqualArgsAndReturnTrue $op)]>;
1222+
12101223
// Test for memrefs normalization of an op with normalizable memrefs.
12111224
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
12121225
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
7070
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
7171
}
7272

73+
static bool assertBinOpEqualArgsAndReturnTrue(Value v) {
74+
Operation *operation = v.getDefiningOp();
75+
if (operation->getOperand(0) != operation->getOperand(1)) {
76+
// Name binding equality check must happen before user-defined constraints,
77+
// thus this must not be triggered.
78+
llvm::report_fatal_error("Arguments are not equal");
79+
}
80+
return true;
81+
}
82+
7383
namespace {
7484
#include "TestPatterns.inc"
7585
} // namespace

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,19 @@ func.func @verifyNestedOpEqualArgs(
156156
// def TestNestedOpEqualArgsPattern :
157157
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;
158158

159-
// CHECK: %arg1
159+
// CHECK: "test.op_o"(%arg1)
160160
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
161161
: (i32, i32, i32, i32, i32, i32) -> (i32)
162162
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)
163+
%2 = "test.op_o"(%1) : (i32) -> (i32)
163164

164-
// CHECK: test.op_p
165-
// CHECK: test.op_n
166-
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
165+
// CHECK-NEXT: %[[P:.*]] = "test.op_p"
166+
// CHECK-NEXT: %[[N:.*]] = "test.op_n"(%arg0, %[[P]])
167+
// CHECK-NEXT: "test.op_o"(%[[N]])
168+
%3 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
167169
: (i32, i32, i32, i32, i32, i32) -> (i32)
168-
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)
170+
%4 = "test.op_n"(%arg0, %3) : (i32, i32) -> (i32)
171+
%5 = "test.op_o"(%4) : (i32) -> (i32)
169172

170173
return
171174
}
@@ -206,6 +209,21 @@ func.func @verifyMultipleEqualArgs(
206209
return
207210
}
208211

212+
func.func @verifyEqualArgsCheckBeforeUserConstraints(%arg0: i32, %arg1: f32) {
213+
// def TestEqualArgsCheckBeforeUserConstraintsPattern :
214+
// Pat<(OpQ:$op $x, $x),
215+
// (replaceWithValue $x),
216+
// [(AssertBinOpEqualArgsAndReturnTrue $op)]>;
217+
218+
// CHECK: "test.op_q"(%arg0, %arg1)
219+
%0 = "test.op_q"(%arg0, %arg1) : (i32, f32) -> (i32)
220+
221+
// CHECK: "test.op_q"(%arg1, %arg0)
222+
%1 = "test.op_q"(%arg1, %arg0) : (f32, i32) -> (i32)
223+
224+
return
225+
}
226+
209227
//===----------------------------------------------------------------------===//
210228
// Test Symbol Binding
211229
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
10241024
int depth = 0;
10251025
emitMatch(tree, opName, depth);
10261026

1027+
// Some of the operands could be bound to the same symbol name, we need
1028+
// to enforce equality constraint on those.
1029+
// This has to happen before user provided constraints, which may assume the
1030+
// same name checks are already performed, since in the pattern source code
1031+
// the user provided constraints appear later.
1032+
// TODO: we should be able to emit equality checks early
1033+
// and short circuit unnecessary work if vars are not equal.
1034+
for (auto symbolInfoIt = symbolInfoMap.begin();
1035+
symbolInfoIt != symbolInfoMap.end();) {
1036+
auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
1037+
auto startRange = range.first;
1038+
auto endRange = range.second;
1039+
1040+
auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
1041+
for (++startRange; startRange != endRange; ++startRange) {
1042+
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
1043+
emitMatchCheck(
1044+
opName,
1045+
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
1046+
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
1047+
secondOperand));
1048+
}
1049+
1050+
symbolInfoIt = endRange;
1051+
}
1052+
10271053
for (auto &appliedConstraint : pattern.getConstraints()) {
10281054
auto &constraint = appliedConstraint.constraint;
10291055
auto &entities = appliedConstraint.entities;
@@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
10681094
}
10691095
}
10701096

1071-
// Some of the operands could be bound to the same symbol name, we need
1072-
// to enforce equality constraint on those.
1073-
// TODO: we should be able to emit equality checks early
1074-
// and short circuit unnecessary work if vars are not equal.
1075-
for (auto symbolInfoIt = symbolInfoMap.begin();
1076-
symbolInfoIt != symbolInfoMap.end();) {
1077-
auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
1078-
auto startRange = range.first;
1079-
auto endRange = range.second;
1080-
1081-
auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
1082-
for (++startRange; startRange != endRange; ++startRange) {
1083-
auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
1084-
emitMatchCheck(
1085-
opName,
1086-
formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
1087-
formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
1088-
secondOperand));
1089-
}
1090-
1091-
symbolInfoIt = endRange;
1092-
}
1093-
10941097
LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
10951098
}
10961099

0 commit comments

Comments
 (0)