Skip to content

Commit 736f47d

Browse files
committed
Add support to TableGen source patterns to match multi-result values by index
1 parent 3fe85ca commit 736f47d

File tree

5 files changed

+57
-5
lines changed

5 files changed

+57
-5
lines changed

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,9 @@ class SymbolInfoMap {
433433
DagAndConstant(node.getAsOpaquePointer(), operandIndex,
434434
variadicSubIndex));
435435
}
436-
static SymbolInfo getResult(const Operator *op) {
437-
return SymbolInfo(op, Kind::Result, std::nullopt);
436+
static SymbolInfo getResult(const Operator *op, int index) {
437+
return SymbolInfo(op, Kind::Result,
438+
DagAndConstant(nullptr, index, std::nullopt));
438439
}
439440
static SymbolInfo getValue() {
440441
return SymbolInfo(nullptr, Kind::Value, std::nullopt);

mlir/lib/TableGen/Pattern.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
370370
case Kind::Result: {
371371
// If `index` is greater than zero, then we are referencing a specific
372372
// result of a multi-result op. The result can still be variadic.
373+
if (index < 0)
374+
index = dagAndConstant->operandIndexOrNumValues;
373375
if (index >= 0) {
374376
std::string v =
375377
std::string(formatv("{0}.getODSResults({1})", name, index));
@@ -442,6 +444,8 @@ std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
442444
return std::string(repl);
443445
}
444446
case Kind::Result: {
447+
if (index < 0)
448+
index = dagAndConstant->operandIndexOrNumValues;
445449
if (index >= 0) {
446450
auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
447451
LLVM_DEBUG(dbgs() << repl << " (SingleResult)\n");
@@ -522,8 +526,9 @@ bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
522526
}
523527

524528
bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
525-
std::string name = getValuePackName(symbol).str();
526-
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
529+
int index = -1;
530+
StringRef name = getValuePackName(symbol, &index);
531+
auto inserted = symbolInfoMap.emplace(name.str(), SymbolInfo::getResult(&op, index));
527532

528533
return symbolInfoMap.count(inserted->first) == 1;
529534
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,16 @@ def OneResultOp3 : TEST_Op<"one_result3"> {
16601660
let results = (outs I32:$result1);
16611661
}
16621662

1663+
def OneResultOp4 : TEST_Op<"one_result4"> {
1664+
let arguments = (ins F32);
1665+
let results = (outs F32);
1666+
}
1667+
1668+
def TwoResultOp2 : TEST_Op<"two_result2"> {
1669+
let arguments = (ins);
1670+
let results = (outs F32, F32);
1671+
}
1672+
16631673
// Test using multi-result op as a whole
16641674
def : Pat<(ThreeResultOp MultiResultOpKind1:$kind),
16651675
(AnotherThreeResultOp $kind)>;
@@ -1696,6 +1706,12 @@ def : Pattern<
16961706
(AnotherTwoResultOp $kind)
16971707
]>;
16981708

1709+
// Test referencing a one-param op whose
1710+
// param comes from the first result of a two-result op.
1711+
def : Pat<
1712+
(OneResultOp4 (TwoResultOp2:$a__1)),
1713+
(replaceWithValue $a__0)>;
1714+
16991715
//===----------------------------------------------------------------------===//
17001716
// Test Patterns (Variadic Ops)
17011717
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/pattern.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,25 @@ func.func @testMatchMixedVaradicOptional(%arg0: i32, %arg1: i32, %arg2: i32, %ar
594594
return
595595
}
596596

597+
// CHECK-LABEL: @replaceOneResultWithNSuffixArgMatch
598+
func.func @replaceOneResultWithNSuffixArgMatch() -> (f32) {
599+
// CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
600+
%0:2 = "test.two_result2"() : () -> (f32, f32)
601+
%1 = "test.one_result4"(%0#1) : (f32) -> (f32)
602+
// CHECK: return %0#0 : f32
603+
return %1 : f32
604+
}
605+
606+
// CHECK-LABEL: @replaceOneResultWithNSuffixArgNoMatch
607+
func.func @replaceOneResultWithNSuffixArgNoMatch() -> (f32) {
608+
// CHECK: %0:2 = "test.two_result2"() : () -> (f32, f32)
609+
%0:2 = "test.two_result2"() : () -> (f32, f32)
610+
// CHECK: %1 = "test.one_result4"(%0#0) : (f32) -> f32
611+
%1 = "test.one_result4"(%0#0) : (f32) -> (f32)
612+
// CHECK: return %1 : f32
613+
return %1 : f32
614+
}
615+
597616
//===----------------------------------------------------------------------===//
598617
// Test patterns that operate on properties
599618
//===----------------------------------------------------------------------===//

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
615615
op.getQualCppClassName()));
616616

617617
// If the operand's name is set, set to that variable.
618-
auto name = tree.getSymbol();
618+
int index = -1;
619+
auto name = SymbolInfoMap::getValuePackName(tree.getSymbol(), &index).str();
619620
if (!name.empty())
620621
os << formatv("{0} = {1};\n", name, castedName);
621622

623+
if (index != -1) {
624+
emitMatchCheck(opName,
625+
formatv("(resultNumber{0} == 1)", depth),
626+
formatv("\"{0} does not come from result number {1} type\"", castedName, index));
627+
}
628+
622629
for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
623630
++i, ++opArgIdx) {
624631
auto opArg = op.getArg(opArgIdx);
@@ -662,6 +669,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
662669
"auto *{0} = "
663670
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
664671
argName, castedName, nextOperand);
672+
os.indent() << formatv(
673+
"[[maybe_unused]] auto resultNumber{0} = "
674+
"::llvm::dyn_cast<::mlir::OpResult>((*{1}.getODSOperands({2}).begin())).getResultNumber();\n",
675+
depth + 1, castedName, nextOperand);
665676
// Null check of operand's definingOp
666677
emitMatchCheck(
667678
castedName, /*matchStr=*/argName,

0 commit comments

Comments
 (0)