Skip to content

Commit 9465ef5

Browse files
authored
[mlir][tblgen] Fix bug when mixing props and InferTypes (#157367)
This patch fixes a bug occurring when properties are mixed with any of the InferType traits, causing tblgen to crash. A simple reproducer is: ``` def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [ AllTypesMatch<["value", "result"]> ]> { let arguments = (ins Property<"unsigned">:$prop, AnyType:$value); let results = (outs AnyType:$result); let hasCustomAssemblyFormat = 1; } ``` The issue occurs because of the call: ``` op.getArgToOperandOrAttribute(infer.getIndex()); ``` To understand better the issue, consider: ``` attrOrOperandMapping = [Operand0] arguments = [Prop0, Operand0] ``` In this case, `infer.getIndex()` will return `1` for `Operand0`, but `getArgToOperandOrAttribute` expects `0`, causing the discrepancy that causes the crash. The fix is to change `attrOrOperandMapping` to also include props.
1 parent 3327a4c commit 9465ef5

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

mlir/include/mlir/TableGen/Operator.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -323,21 +323,22 @@ class Operator {
323323
/// Requires: all result types are known.
324324
const InferredResultType &getInferredResultType(int index) const;
325325

326-
/// Pair consisting kind of argument and index into operands or attributes.
327-
struct OperandOrAttribute {
328-
enum class Kind { Operand, Attribute };
329-
OperandOrAttribute(Kind kind, int index) {
330-
packed = (index << 1) | (kind == Kind::Attribute);
326+
/// Pair consisting kind of argument and index into operands, attributes, or
327+
/// properties.
328+
struct OperandAttrOrProp {
329+
enum class Kind { Operand = 0x0, Attribute = 0x1, Property = 0x2 };
330+
OperandAttrOrProp(Kind kind, int index) {
331+
packed = (index << 2) | static_cast<int>(kind);
331332
}
332-
int operandOrAttributeIndex() const { return (packed >> 1); }
333-
Kind kind() { return (packed & 0x1) ? Kind::Attribute : Kind::Operand; }
333+
int operandOrAttributeIndex() const { return (packed >> 2); }
334+
Kind kind() const { return static_cast<Kind>(packed & 0x3); }
334335

335336
private:
336337
int packed;
337338
};
338339

339-
/// Returns the OperandOrAttribute corresponding to the index.
340-
OperandOrAttribute getArgToOperandOrAttribute(int index) const;
340+
/// Returns the OperandAttrOrProp corresponding to the index.
341+
OperandAttrOrProp getArgToOperandAttrOrProp(int index) const;
341342

342343
/// Returns the builders of this operation.
343344
ArrayRef<Builder> getBuilders() const { return builders; }
@@ -405,8 +406,8 @@ class Operator {
405406
/// The argument with the same type as the result.
406407
SmallVector<InferredResultType> resultTypeMapping;
407408

408-
/// Map from argument to attribute or operand number.
409-
SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
409+
/// Map from argument to attribute, property, or operand number.
410+
SmallVector<OperandAttrOrProp, 4> attrPropOrOperandMapping;
410411

411412
/// The builders of this operator.
412413
SmallVector<Builder> builders;

mlir/lib/TableGen/Operator.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,8 @@ void Operator::populateTypeInferenceInfo(
385385
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
386386
// Check for a non-variable length operand to use as the type anchor.
387387
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
388-
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
388+
NamedTypeConstraint *operand =
389+
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
389390
return operand && !operand->isVariableLength();
390391
});
391392
if (operandI == arguments.end())
@@ -663,15 +664,17 @@ void Operator::populateOpStructure() {
663664
argDef = argDef->getValueAsDef("constraint");
664665

665666
if (argDef->isSubClassOf(typeConstraintClass)) {
666-
attrOrOperandMapping.push_back(
667-
{OperandOrAttribute::Kind::Operand, operandIndex});
667+
attrPropOrOperandMapping.push_back(
668+
{OperandAttrOrProp::Kind::Operand, operandIndex});
668669
arguments.emplace_back(&operands[operandIndex++]);
669670
} else if (argDef->isSubClassOf(attrClass)) {
670-
attrOrOperandMapping.push_back(
671-
{OperandOrAttribute::Kind::Attribute, attrIndex});
671+
attrPropOrOperandMapping.push_back(
672+
{OperandAttrOrProp::Kind::Attribute, attrIndex});
672673
arguments.emplace_back(&attributes[attrIndex++]);
673674
} else {
674675
assert(argDef->isSubClassOf(propertyClass));
676+
attrPropOrOperandMapping.push_back(
677+
{OperandAttrOrProp::Kind::Property, propIndex});
675678
arguments.emplace_back(&properties[propIndex++]);
676679
}
677680
}
@@ -867,9 +870,8 @@ auto Operator::VariableDecoratorIterator::unwrap(const Init *init)
867870
return VariableDecorator(cast<DefInit>(init)->getDef());
868871
}
869872

870-
auto Operator::getArgToOperandOrAttribute(int index) const
871-
-> OperandOrAttribute {
872-
return attrOrOperandMapping[index];
873+
auto Operator::getArgToOperandAttrOrProp(int index) const -> OperandAttrOrProp {
874+
return attrPropOrOperandMapping[index];
873875
}
874876

875877
std::string Operator::getGetterName(StringRef name) const {

mlir/test/mlir-tblgen/op-decl-and-defs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,12 @@ def _BOp : NS_Op<"_op_with_leading_underscore_and_no_namespace", []>;
543543

544544
// REDUCE_EXC-NOT: NS::AOp declarations
545545
// REDUCE_EXC-LABEL: NS::BOp declarations
546+
547+
// CHECK-LABEL: _TypeInferredPropOp declarations
548+
def _TypeInferredPropOp : NS_Op<"type_inferred_prop_op_with_properties", [
549+
AllTypesMatch<["value", "result"]>
550+
]> {
551+
let arguments = (ins Property<"unsigned">:$prop, AnyType:$value);
552+
let results = (outs AnyType:$result);
553+
let hasCustomAssemblyFormat = 1;
554+
}

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3849,9 +3849,9 @@ void OpEmitter::genTypeInterfaceMethods() {
38493849
const InferredResultType &infer = op.getInferredResultType(i);
38503850
if (!infer.isArg())
38513851
continue;
3852-
Operator::OperandOrAttribute arg =
3853-
op.getArgToOperandOrAttribute(infer.getIndex());
3854-
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3852+
Operator::OperandAttrOrProp arg =
3853+
op.getArgToOperandAttrOrProp(infer.getIndex());
3854+
if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
38553855
maxAccessedIndex =
38563856
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
38573857
}
@@ -3877,17 +3877,16 @@ void OpEmitter::genTypeInterfaceMethods() {
38773877
if (infer.isArg()) {
38783878
// If this is an operand, just index into operand list to access the
38793879
// type.
3880-
Operator::OperandOrAttribute arg =
3881-
op.getArgToOperandOrAttribute(infer.getIndex());
3882-
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3880+
Operator::OperandAttrOrProp arg =
3881+
op.getArgToOperandAttrOrProp(infer.getIndex());
3882+
if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) {
38833883
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
38843884
"].getType()")
38853885
.str();
38863886

38873887
// If this is an attribute, index into the attribute dictionary.
3888-
} else {
3889-
auto *attr =
3890-
cast<NamedAttribute *>(op.getArg(arg.operandOrAttributeIndex()));
3888+
} else if (auto *attr = dyn_cast<NamedAttribute *>(
3889+
op.getArg(arg.operandOrAttributeIndex()))) {
38913890
body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
38923891
<< " = ";
38933892
if (op.getDialect().usePropertiesForAttributes()) {
@@ -3907,6 +3906,9 @@ void OpEmitter::genTypeInterfaceMethods() {
39073906
typeStr =
39083907
("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
39093908
.str();
3909+
} else {
3910+
llvm::PrintFatalError(&op.getDef(),
3911+
"Properties cannot be used for type inference");
39103912
}
39113913
} else if (std::optional<StringRef> builder =
39123914
op.getResult(infer.getResultIndex())

0 commit comments

Comments
 (0)