Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit e9d2818

Browse files
authored
[MLIR][Python] add type hints for accessors (#158455)
This PR adds type hints for accessors in the generated builders.
1 parent 1ece67c commit e9d2818

File tree

2 files changed

+141
-34
lines changed

2 files changed

+141
-34
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,9 +1742,9 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
17421742
return nb::cast(PyBoolAttribute(pyAttribute));
17431743
if (PyIntegerAttribute::isaFunction(pyAttribute))
17441744
return nb::cast(PyIntegerAttribute(pyAttribute));
1745-
std::string msg =
1746-
std::string("Can't cast unknown element type DenseArrayAttr (") +
1747-
nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1745+
std::string msg = std::string("Can't cast unknown attribute type Attr (") +
1746+
nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1747+
")";
17481748
throw nb::type_error(msg.c_str());
17491749
}
17501750

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 138 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ _ods_ir = _ods_cext.ir
4545
_ods_cext.globals.register_traceback_file_exclusion(__file__)
4646
4747
import builtins
48-
from typing import Sequence as _Sequence, Union as _Union
48+
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
4949
5050
)Py";
5151

@@ -95,9 +95,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
9595
/// {0} is the name of the accessor;
9696
/// {1} is either 'operand' or 'result';
9797
/// {2} is the position in the element list.
98+
/// {3} is the type hint.
9899
constexpr const char *opSingleTemplate = R"Py(
99100
@builtins.property
100-
def {0}(self):
101+
def {0}(self) -> {3}:
101102
return self.operation.{1}s[{2}]
102103
)Py";
103104

@@ -106,11 +107,12 @@ constexpr const char *opSingleTemplate = R"Py(
106107
/// {1} is either 'operand' or 'result';
107108
/// {2} is the total number of element groups;
108109
/// {3} is the position of the current group in the group list.
110+
/// {4} is the type hint.
109111
/// This works for both a single variadic group (non-negative length) and an
110112
/// single optional element (zero length if the element is absent).
111113
constexpr const char *opSingleAfterVariableTemplate = R"Py(
112114
@builtins.property
113-
def {0}(self):
115+
def {0}(self) -> {4}:
114116
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
115117
return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
116118
)Py";
@@ -120,12 +122,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
120122
/// {1} is either 'operand' or 'result';
121123
/// {2} is the total number of element groups;
122124
/// {3} is the position of the current group in the group list.
125+
/// {4} is the type hint.
123126
/// This works if we have only one variable-length group (and it's the optional
124127
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
125128
/// smaller than the total number of groups.
126129
constexpr const char *opOneOptionalTemplate = R"Py(
127130
@builtins.property
128-
def {0}(self):
131+
def {0}(self) -> _Optional[{4}]:
129132
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
130133
)Py";
131134

@@ -134,9 +137,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
134137
/// {1} is either 'operand' or 'result';
135138
/// {2} is the total number of element groups;
136139
/// {3} is the position of the current group in the group list.
140+
/// {4} is the type hint.
137141
constexpr const char *opOneVariadicTemplate = R"Py(
138142
@builtins.property
139-
def {0}(self):
143+
def {0}(self) -> {4}:
140144
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
141145
return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
142146
)Py";
@@ -148,9 +152,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
148152
/// {3} is the total number of variadic groups;
149153
/// {4} is the number of non-variadic groups preceding the current group;
150154
/// {5} is the number of variadic groups preceding the current group.
155+
/// {6} is the type hint.
151156
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
152157
@builtins.property
153-
def {0}(self):
158+
def {0}(self) -> {6}:
154159
start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
155160

156161
/// Second part of the template for equally-sized case, accessing a single
@@ -173,9 +178,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
173178
/// {2} is the position of the group in the group list;
174179
/// {3} is a return suffix (expected [0] for single-element, empty for
175180
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
181+
/// {4} is the type hint.
176182
constexpr const char *opVariadicSegmentTemplate = R"Py(
177183
@builtins.property
178-
def {0}(self):
184+
def {0}(self) -> {4}:
179185
{1}_range = _ods_segmented_accessor(
180186
self.operation.{1}s,
181187
self.operation.attributes["{1}SegmentSizes"], {2})
@@ -191,18 +197,20 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
191197
/// Template for an operation attribute getter:
192198
/// {0} is the name of the attribute sanitized for Python;
193199
/// {1} is the original name of the attribute.
200+
/// {2} is the type hint.
194201
constexpr const char *attributeGetterTemplate = R"Py(
195202
@builtins.property
196-
def {0}(self):
203+
def {0}(self) -> {2}:
197204
return self.operation.attributes["{1}"]
198205
)Py";
199206

200207
/// Template for an optional operation attribute getter:
201208
/// {0} is the name of the attribute sanitized for Python;
202209
/// {1} is the original name of the attribute.
210+
/// {2} is the type hint.
203211
constexpr const char *optionalAttributeGetterTemplate = R"Py(
204212
@builtins.property
205-
def {0}(self):
213+
def {0}(self) -> _Optional[{2}]:
206214
if "{1}" not in self.operation.attributes:
207215
return None
208216
return self.operation.attributes["{1}"]
@@ -215,16 +223,17 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
215223
/// {1} is the original name of the attribute.
216224
constexpr const char *unitAttributeGetterTemplate = R"Py(
217225
@builtins.property
218-
def {0}(self):
226+
def {0}(self) -> bool:
219227
return "{1}" in self.operation.attributes
220228
)Py";
221229

222230
/// Template for an operation attribute setter:
223231
/// {0} is the name of the attribute sanitized for Python;
224232
/// {1} is the original name of the attribute.
233+
/// {2} is the type hint.
225234
constexpr const char *attributeSetterTemplate = R"Py(
226235
@{0}.setter
227-
def {0}(self, value):
236+
def {0}(self, value: {2}):
228237
if value is None:
229238
raise ValueError("'None' not allowed as value for mandatory attributes")
230239
self.operation.attributes["{1}"] = value
@@ -234,9 +243,10 @@ constexpr const char *attributeSetterTemplate = R"Py(
234243
/// removes the attribute:
235244
/// {0} is the name of the attribute sanitized for Python;
236245
/// {1} is the original name of the attribute.
246+
/// {2} is the type hint.
237247
constexpr const char *optionalAttributeSetterTemplate = R"Py(
238248
@{0}.setter
239-
def {0}(self, value):
249+
def {0}(self, value: _Optional[{2}]):
240250
if value is not None:
241251
self.operation.attributes["{1}"] = value
242252
elif "{1}" in self.operation.attributes:
@@ -268,7 +278,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
268278

269279
constexpr const char *regionAccessorTemplate = R"Py(
270280
@builtins.property
271-
def {0}(self):
281+
def {0}(self) -> {2}:
272282
return self.regions[{1}]
273283
)Py";
274284

@@ -360,15 +370,24 @@ static void emitElementAccessors(
360370
seenVariableLength = true;
361371
if (element.name.empty())
362372
continue;
373+
const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
374+
: "_ods_ir.OpResult";
363375
if (element.isVariableLength()) {
364-
os << formatv(element.isOptional() ? opOneOptionalTemplate
365-
: opOneVariadicTemplate,
366-
sanitizeName(element.name), kind, numElements, i);
376+
if (element.isOptional()) {
377+
os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
378+
numElements, i, type);
379+
} else {
380+
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
381+
: "_ods_ir.OpResultList";
382+
os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
383+
numElements, i, type);
384+
}
367385
} else if (seenVariableLength) {
368386
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
369-
kind, numElements, i);
387+
kind, numElements, i, type);
370388
} else {
371-
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
389+
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
390+
type);
372391
}
373392
}
374393
return;
@@ -391,9 +410,17 @@ static void emitElementAccessors(
391410
for (unsigned i = 0; i < numElements; ++i) {
392411
const NamedTypeConstraint &element = getElement(op, i);
393412
if (!element.name.empty()) {
413+
std::string type;
414+
if (element.isVariableLength()) {
415+
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
416+
: "_ods_ir.OpResultList";
417+
} else {
418+
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
419+
: "_ods_ir.OpResult";
420+
}
394421
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
395422
kind, numSimpleLength, numVariadicGroups,
396-
numPrecedingSimple, numPrecedingVariadic);
423+
numPrecedingSimple, numPrecedingVariadic, type);
397424
os << formatv(element.isVariableLength()
398425
? opVariadicEqualVariadicTemplate
399426
: opVariadicEqualSimpleTemplate,
@@ -416,13 +443,23 @@ static void emitElementAccessors(
416443
if (element.name.empty())
417444
continue;
418445
std::string trailing;
419-
if (!element.isVariableLength())
420-
trailing = "[0]";
421-
else if (element.isOptional())
422-
trailing = std::string(
423-
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
446+
std::string type = std::strcmp(kind, "operand") == 0
447+
? "_ods_ir.OpOperandList"
448+
: "_ods_ir.OpResultList";
449+
if (!element.isVariableLength() || element.isOptional()) {
450+
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
451+
: "_ods_ir.OpResult";
452+
if (!element.isVariableLength()) {
453+
trailing = "[0]";
454+
} else if (element.isOptional()) {
455+
type = "_Optional[" + type + "]";
456+
trailing = std::string(
457+
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
458+
}
459+
}
460+
424461
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
425-
i, trailing);
462+
i, trailing, type);
426463
}
427464
return;
428465
}
@@ -452,6 +489,72 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
452489
getNumResults(op), getResult);
453490
}
454491

492+
static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
493+
auto storageTypeStr = attr.getStorageType();
494+
if (storageTypeStr == "::mlir::AffineMapAttr")
495+
return "AffineMapAttr";
496+
if (storageTypeStr == "::mlir::ArrayAttr")
497+
return "ArrayAttr";
498+
if (storageTypeStr == "::mlir::BoolAttr")
499+
return "BoolAttr";
500+
if (storageTypeStr == "::mlir::DenseBoolArrayAttr")
501+
return "DenseBoolArrayAttr";
502+
if (storageTypeStr == "::mlir::DenseElementsAttr") {
503+
llvm::StringSet<> superClasses;
504+
for (const Record *sc : attr.getDef().getSuperClasses())
505+
superClasses.insert(sc->getNameInitAsString());
506+
if (superClasses.contains("FloatElementsAttr") ||
507+
superClasses.contains("RankedFloatElementsAttr")) {
508+
return "DenseFPElementsAttr";
509+
}
510+
return "DenseElementsAttr";
511+
}
512+
if (storageTypeStr == "::mlir::DenseF32ArrayAttr")
513+
return "DenseF32ArrayAttr";
514+
if (storageTypeStr == "::mlir::DenseF64ArrayAttr")
515+
return "DenseF64ArrayAttr";
516+
if (storageTypeStr == "::mlir::DenseFPElementsAttr")
517+
return "DenseFPElementsAttr";
518+
if (storageTypeStr == "::mlir::DenseI16ArrayAttr")
519+
return "DenseI16ArrayAttr";
520+
if (storageTypeStr == "::mlir::DenseI32ArrayAttr")
521+
return "DenseI32ArrayAttr";
522+
if (storageTypeStr == "::mlir::DenseI64ArrayAttr")
523+
return "DenseI64ArrayAttr";
524+
if (storageTypeStr == "::mlir::DenseI8ArrayAttr")
525+
return "DenseI8ArrayAttr";
526+
if (storageTypeStr == "::mlir::DenseIntElementsAttr")
527+
return "DenseIntElementsAttr";
528+
if (storageTypeStr == "::mlir::DenseResourceElementsAttr")
529+
return "DenseResourceElementsAttr";
530+
if (storageTypeStr == "::mlir::DictionaryAttr")
531+
return "DictAttr";
532+
if (storageTypeStr == "::mlir::FlatSymbolRefAttr")
533+
return "FlatSymbolRefAttr";
534+
if (storageTypeStr == "::mlir::FloatAttr")
535+
return "FloatAttr";
536+
if (storageTypeStr == "::mlir::IntegerAttr") {
537+
if (attr.getAttrDefName().str() == "I1Attr")
538+
return "BoolAttr";
539+
return "IntegerAttr";
540+
}
541+
if (storageTypeStr == "::mlir::IntegerSetAttr")
542+
return "IntegerSetAttr";
543+
if (storageTypeStr == "::mlir::OpaqueAttr")
544+
return "OpaqueAttr";
545+
if (storageTypeStr == "::mlir::StridedLayoutAttr")
546+
return "StridedLayoutAttr";
547+
if (storageTypeStr == "::mlir::StringAttr")
548+
return "StringAttr";
549+
if (storageTypeStr == "::mlir::SymbolRefAttr")
550+
return "SymbolRefAttr";
551+
if (storageTypeStr == "::mlir::TypeAttr")
552+
return "TypeAttr";
553+
if (storageTypeStr == "::mlir::UnitAttr")
554+
return "UnitAttr";
555+
return "Attribute";
556+
}
557+
455558
/// Emits accessors to Op attributes.
456559
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
457560
for (const auto &namedAttr : op.getAttributes()) {
@@ -473,15 +576,18 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
473576
continue;
474577
}
475578

579+
std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
476580
if (namedAttr.attr.isOptional()) {
477581
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
478-
namedAttr.name);
582+
namedAttr.name, type);
479583
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
480-
namedAttr.name);
584+
namedAttr.name, type);
481585
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
482586
} else {
483-
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
484-
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
587+
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name,
588+
type);
589+
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name,
590+
type);
485591
// Non-optional attributes cannot be deleted.
486592
}
487593
}
@@ -983,8 +1089,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
9831089
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
9841090
"expected only the last region to be variadic");
9851091
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
986-
std::to_string(en.index()) +
987-
(region.isVariadic() ? ":" : ""));
1092+
std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
1093+
region.isVariadic() ? "_ods_ir.RegionSequence"
1094+
: "_ods_ir.Region");
9881095
}
9891096
}
9901097

0 commit comments

Comments
 (0)