Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,11 @@ def TestEitherOpB : TEST_Op<"either_op_b"> {
let results = (outs I32:$output);
}

def TestEitherOpC : TEST_Op<"either_op_c"> {
let arguments = (ins AnyI32Attr:$attr, AnyInteger:$arg0, AnyInteger:$arg1);
let results = (outs I32:$output);
}

def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $x),
(TestEitherOpB $arg2, $x)>;

Expand All @@ -1883,6 +1888,9 @@ def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1, $_),
$x),
(TestEitherOpB $arg2, $x)>;

def : Pat<(TestEitherOpC ConstantAttr<I32Attr, "0">, (either $arg1, I32:$arg2)),
(TestEitherOpB $arg1, $arg2)>;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with this without the changes to the RewriterGen.cpp file?

Copy link
Contributor Author

@xl4624 xl4624 May 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the changes this test fails on a cast assertion:

mlir-tblgen: /home/xiaomin/dev/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::tblgen::NamedTypeConstraint *, From = llvm::PointerUnion<mlir::tblgen::NamedAttribute *, mlir::tblgen::NamedProperty *, mlir::tblgen::NamedTypeConstraint *>]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Coming from:

} else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
operandIndex,
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex,
/*variadicSubIndex=*/std::nullopt);

void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
StringRef operandName, int operandIndex,
DagLeaf operandMatcher, StringRef argName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));

In emitEitherOperand(), we check if op.getArg(argIndex) is a NamedTypeConstraint, but in emitOperandMatch we cast op.getArg(operandIndex). In the test case above, operandIndex and argIndex get out of sync due to the Attribute being in the front which leads to a cast with the wrong index (specifically argIndex=1 referring to $arg1 and operandIndex=0 referring to the ConstantAttr)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for this rewrite though to ensure it works as expected?

Copy link
Contributor Author

@xl4624 xl4624 May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test for this in pattern.mlir, also cleaned up surrounding tests to match the style of the rest of the file.

def TestEitherHelperOpA : TEST_Op<"either_helper_op_a"> {
let arguments = (ins I32:$arg0);
let results = (outs I32:$output);
Expand Down
14 changes: 7 additions & 7 deletions mlir/tools/mlir-tblgen/RewriterGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
if (isa<NamedTypeConstraint *>(opArg)) {
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
emitOperandMatch(tree, castedName, operandName.str(), nextOperand,
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i), opArgIdx,
/*variadicSubIndex=*/std::nullopt);
Expand All @@ -680,7 +680,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex,
std::optional<int> variadicSubIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex));
NamedTypeConstraint operand = op.getOperand(operandIndex);

// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
Expand All @@ -693,8 +693,8 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = operandMatcher.getAsConstraint();
if (operand->constraint != constraint) {
if (operand->isVariableLength()) {
if (operand.constraint != constraint) {
if (operand.isVariableLength()) {
auto error = formatv(
"further constrain op {0}'s variadic operand #{1} unsupported now",
op.getOperationName(), argIndex);
Expand All @@ -706,7 +706,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
operand - op.operand_begin(), op.getOperationName(),
operandIndex, op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
Expand All @@ -715,7 +715,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
// Capture the value
// `$_` is a special symbol to ignore op argument matching.
if (!argName.empty() && argName != "_") {
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex,
variadicSubIndex);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", argName));
Expand Down Expand Up @@ -821,7 +821,7 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
StringRef variadicTreeName = variadicArgTree.getSymbol();
if (!variadicTreeName.empty()) {
auto res =
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, argIndex,
/*variadicSubIndex=*/std::nullopt);
if (res == symbolInfoMap.end())
PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
Expand Down