@@ -45,7 +45,7 @@ _ods_ir = _ods_cext.ir
4545_ods_cext.globals.register_traceback_file_exclusion(__file__)
4646
4747import builtins
48- from typing import Sequence as _Sequence, Union as _Union
48+ from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional
4949
5050)Py" ;
5151
@@ -95,9 +95,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py(
9595// / {0} is the name of the accessor;
9696// / {1} is either 'operand' or 'result';
9797// / {2} is the position in the element list.
98+ // / {3} is the type hint.
9899constexpr const char *opSingleTemplate = R"Py(
99100 @builtins.property
100- def {0}(self):
101+ def {0}(self) -> {3} :
101102 return self.operation.{1}s[{2}]
102103)Py" ;
103104
@@ -106,11 +107,12 @@ constexpr const char *opSingleTemplate = R"Py(
106107// / {1} is either 'operand' or 'result';
107108// / {2} is the total number of element groups;
108109// / {3} is the position of the current group in the group list.
110+ // / {4} is the type hint.
109111// / This works for both a single variadic group (non-negative length) and an
110112// / single optional element (zero length if the element is absent).
111113constexpr const char *opSingleAfterVariableTemplate = R"Py(
112114 @builtins.property
113- def {0}(self):
115+ def {0}(self) -> {4} :
114116 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
115117 return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
116118)Py" ;
@@ -120,12 +122,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py(
120122// / {1} is either 'operand' or 'result';
121123// / {2} is the total number of element groups;
122124// / {3} is the position of the current group in the group list.
125+ // / {4} is the type hint.
123126// / This works if we have only one variable-length group (and it's the optional
124127// / operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
125128// / smaller than the total number of groups.
126129constexpr const char *opOneOptionalTemplate = R"Py(
127130 @builtins.property
128- def {0}(self):
131+ def {0}(self) -> _Optional[{4}] :
129132 return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
130133)Py" ;
131134
@@ -134,9 +137,10 @@ constexpr const char *opOneOptionalTemplate = R"Py(
134137// / {1} is either 'operand' or 'result';
135138// / {2} is the total number of element groups;
136139// / {3} is the position of the current group in the group list.
140+ // / {4} is the type hint.
137141constexpr const char *opOneVariadicTemplate = R"Py(
138142 @builtins.property
139- def {0}(self):
143+ def {0}(self) -> {4} :
140144 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
141145 return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
142146)Py" ;
@@ -148,9 +152,10 @@ constexpr const char *opOneVariadicTemplate = R"Py(
148152// / {3} is the total number of variadic groups;
149153// / {4} is the number of non-variadic groups preceding the current group;
150154// / {5} is the number of variadic groups preceding the current group.
155+ // / {6} is the type hint.
151156constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
152157 @builtins.property
153- def {0}(self):
158+ def {0}(self) -> {6} :
154159 start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py" ;
155160
156161// / Second part of the template for equally-sized case, accessing a single
@@ -173,9 +178,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
173178// / {2} is the position of the group in the group list;
174179// / {3} is a return suffix (expected [0] for single-element, empty for
175180// / variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
181+ // / {4} is the type hint.
176182constexpr const char *opVariadicSegmentTemplate = R"Py(
177183 @builtins.property
178- def {0}(self):
184+ def {0}(self) -> {4} :
179185 {1}_range = _ods_segmented_accessor(
180186 self.operation.{1}s,
181187 self.operation.attributes["{1}SegmentSizes"], {2})
@@ -191,18 +197,20 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
191197// / Template for an operation attribute getter:
192198// / {0} is the name of the attribute sanitized for Python;
193199// / {1} is the original name of the attribute.
200+ // / {2} is the type hint.
194201constexpr const char *attributeGetterTemplate = R"Py(
195202 @builtins.property
196- def {0}(self):
203+ def {0}(self) -> {2} :
197204 return self.operation.attributes["{1}"]
198205)Py" ;
199206
200207// / Template for an optional operation attribute getter:
201208// / {0} is the name of the attribute sanitized for Python;
202209// / {1} is the original name of the attribute.
210+ // / {2} is the type hint.
203211constexpr const char *optionalAttributeGetterTemplate = R"Py(
204212 @builtins.property
205- def {0}(self):
213+ def {0}(self) -> _Optional[{2}] :
206214 if "{1}" not in self.operation.attributes:
207215 return None
208216 return self.operation.attributes["{1}"]
@@ -215,16 +223,17 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py(
215223// / {1} is the original name of the attribute.
216224constexpr const char *unitAttributeGetterTemplate = R"Py(
217225 @builtins.property
218- def {0}(self):
226+ def {0}(self) -> bool :
219227 return "{1}" in self.operation.attributes
220228)Py" ;
221229
222230// / Template for an operation attribute setter:
223231// / {0} is the name of the attribute sanitized for Python;
224232// / {1} is the original name of the attribute.
233+ // / {2} is the type hint.
225234constexpr const char *attributeSetterTemplate = R"Py(
226235 @{0}.setter
227- def {0}(self, value):
236+ def {0}(self, value: {2} ):
228237 if value is None:
229238 raise ValueError("'None' not allowed as value for mandatory attributes")
230239 self.operation.attributes["{1}"] = value
@@ -234,9 +243,10 @@ constexpr const char *attributeSetterTemplate = R"Py(
234243// / removes the attribute:
235244// / {0} is the name of the attribute sanitized for Python;
236245// / {1} is the original name of the attribute.
246+ // / {2} is the type hint.
237247constexpr const char *optionalAttributeSetterTemplate = R"Py(
238248 @{0}.setter
239- def {0}(self, value):
249+ def {0}(self, value: _Optional[{2}] ):
240250 if value is not None:
241251 self.operation.attributes["{1}"] = value
242252 elif "{1}" in self.operation.attributes:
@@ -268,7 +278,7 @@ constexpr const char *attributeDeleterTemplate = R"Py(
268278
269279constexpr const char *regionAccessorTemplate = R"Py(
270280 @builtins.property
271- def {0}(self):
281+ def {0}(self) -> {2} :
272282 return self.regions[{1}]
273283)Py" ;
274284
@@ -360,15 +370,24 @@ static void emitElementAccessors(
360370 seenVariableLength = true ;
361371 if (element.name .empty ())
362372 continue ;
373+ const char *type = std::strcmp (kind, " operand" ) == 0 ? " _ods_ir.Value"
374+ : " _ods_ir.OpResult" ;
363375 if (element.isVariableLength ()) {
364- os << formatv (element.isOptional () ? opOneOptionalTemplate
365- : opOneVariadicTemplate,
366- sanitizeName (element.name ), kind, numElements, i);
376+ if (element.isOptional ()) {
377+ os << formatv (opOneOptionalTemplate, sanitizeName (element.name ), kind,
378+ numElements, i, type);
379+ } else {
380+ type = std::strcmp (kind, " operand" ) == 0 ? " _ods_ir.OpOperandList"
381+ : " _ods_ir.OpResultList" ;
382+ os << formatv (opOneVariadicTemplate, sanitizeName (element.name ), kind,
383+ numElements, i, type);
384+ }
367385 } else if (seenVariableLength) {
368386 os << formatv (opSingleAfterVariableTemplate, sanitizeName (element.name ),
369- kind, numElements, i);
387+ kind, numElements, i, type );
370388 } else {
371- os << formatv (opSingleTemplate, sanitizeName (element.name ), kind, i);
389+ os << formatv (opSingleTemplate, sanitizeName (element.name ), kind, i,
390+ type);
372391 }
373392 }
374393 return ;
@@ -391,9 +410,17 @@ static void emitElementAccessors(
391410 for (unsigned i = 0 ; i < numElements; ++i) {
392411 const NamedTypeConstraint &element = getElement (op, i);
393412 if (!element.name .empty ()) {
413+ std::string type;
414+ if (element.isVariableLength ()) {
415+ type = std::strcmp (kind, " operand" ) == 0 ? " _ods_ir.OpOperandList"
416+ : " _ods_ir.OpResultList" ;
417+ } else {
418+ type = std::strcmp (kind, " operand" ) == 0 ? " _ods_ir.Value"
419+ : " _ods_ir.OpResult" ;
420+ }
394421 os << formatv (opVariadicEqualPrefixTemplate, sanitizeName (element.name ),
395422 kind, numSimpleLength, numVariadicGroups,
396- numPrecedingSimple, numPrecedingVariadic);
423+ numPrecedingSimple, numPrecedingVariadic, type );
397424 os << formatv (element.isVariableLength ()
398425 ? opVariadicEqualVariadicTemplate
399426 : opVariadicEqualSimpleTemplate,
@@ -416,13 +443,23 @@ static void emitElementAccessors(
416443 if (element.name .empty ())
417444 continue ;
418445 std::string trailing;
419- if (!element.isVariableLength ())
420- trailing = " [0]" ;
421- else if (element.isOptional ())
422- trailing = std::string (
423- formatv (opVariadicSegmentOptionalTrailingTemplate, kind));
446+ std::string type = std::strcmp (kind, " operand" ) == 0
447+ ? " _ods_ir.OpOperandList"
448+ : " _ods_ir.OpResultList" ;
449+ if (!element.isVariableLength () || element.isOptional ()) {
450+ type = std::strcmp (kind, " operand" ) == 0 ? " _ods_ir.Value"
451+ : " _ods_ir.OpResult" ;
452+ if (!element.isVariableLength ()) {
453+ trailing = " [0]" ;
454+ } else if (element.isOptional ()) {
455+ type = " _Optional[" + type + " ]" ;
456+ trailing = std::string (
457+ formatv (opVariadicSegmentOptionalTrailingTemplate, kind));
458+ }
459+ }
460+
424461 os << formatv (opVariadicSegmentTemplate, sanitizeName (element.name ), kind,
425- i, trailing);
462+ i, trailing, type );
426463 }
427464 return ;
428465 }
@@ -452,6 +489,72 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) {
452489 getNumResults (op), getResult);
453490}
454491
492+ static std::string getPythonAttrName (mlir::tblgen::Attribute attr) {
493+ auto storageTypeStr = attr.getStorageType ();
494+ if (storageTypeStr == " ::mlir::AffineMapAttr" )
495+ return " AffineMapAttr" ;
496+ if (storageTypeStr == " ::mlir::ArrayAttr" )
497+ return " ArrayAttr" ;
498+ if (storageTypeStr == " ::mlir::BoolAttr" )
499+ return " BoolAttr" ;
500+ if (storageTypeStr == " ::mlir::DenseBoolArrayAttr" )
501+ return " DenseBoolArrayAttr" ;
502+ if (storageTypeStr == " ::mlir::DenseElementsAttr" ) {
503+ llvm::StringSet<> superClasses;
504+ for (const Record *sc : attr.getDef ().getSuperClasses ())
505+ superClasses.insert (sc->getNameInitAsString ());
506+ if (superClasses.contains (" FloatElementsAttr" ) ||
507+ superClasses.contains (" RankedFloatElementsAttr" )) {
508+ return " DenseFPElementsAttr" ;
509+ }
510+ return " DenseElementsAttr" ;
511+ }
512+ if (storageTypeStr == " ::mlir::DenseF32ArrayAttr" )
513+ return " DenseF32ArrayAttr" ;
514+ if (storageTypeStr == " ::mlir::DenseF64ArrayAttr" )
515+ return " DenseF64ArrayAttr" ;
516+ if (storageTypeStr == " ::mlir::DenseFPElementsAttr" )
517+ return " DenseFPElementsAttr" ;
518+ if (storageTypeStr == " ::mlir::DenseI16ArrayAttr" )
519+ return " DenseI16ArrayAttr" ;
520+ if (storageTypeStr == " ::mlir::DenseI32ArrayAttr" )
521+ return " DenseI32ArrayAttr" ;
522+ if (storageTypeStr == " ::mlir::DenseI64ArrayAttr" )
523+ return " DenseI64ArrayAttr" ;
524+ if (storageTypeStr == " ::mlir::DenseI8ArrayAttr" )
525+ return " DenseI8ArrayAttr" ;
526+ if (storageTypeStr == " ::mlir::DenseIntElementsAttr" )
527+ return " DenseIntElementsAttr" ;
528+ if (storageTypeStr == " ::mlir::DenseResourceElementsAttr" )
529+ return " DenseResourceElementsAttr" ;
530+ if (storageTypeStr == " ::mlir::DictionaryAttr" )
531+ return " DictAttr" ;
532+ if (storageTypeStr == " ::mlir::FlatSymbolRefAttr" )
533+ return " FlatSymbolRefAttr" ;
534+ if (storageTypeStr == " ::mlir::FloatAttr" )
535+ return " FloatAttr" ;
536+ if (storageTypeStr == " ::mlir::IntegerAttr" ) {
537+ if (attr.getAttrDefName ().str () == " I1Attr" )
538+ return " BoolAttr" ;
539+ return " IntegerAttr" ;
540+ }
541+ if (storageTypeStr == " ::mlir::IntegerSetAttr" )
542+ return " IntegerSetAttr" ;
543+ if (storageTypeStr == " ::mlir::OpaqueAttr" )
544+ return " OpaqueAttr" ;
545+ if (storageTypeStr == " ::mlir::StridedLayoutAttr" )
546+ return " StridedLayoutAttr" ;
547+ if (storageTypeStr == " ::mlir::StringAttr" )
548+ return " StringAttr" ;
549+ if (storageTypeStr == " ::mlir::SymbolRefAttr" )
550+ return " SymbolRefAttr" ;
551+ if (storageTypeStr == " ::mlir::TypeAttr" )
552+ return " TypeAttr" ;
553+ if (storageTypeStr == " ::mlir::UnitAttr" )
554+ return " UnitAttr" ;
555+ return " Attribute" ;
556+ }
557+
455558// / Emits accessors to Op attributes.
456559static void emitAttributeAccessors (const Operator &op, raw_ostream &os) {
457560 for (const auto &namedAttr : op.getAttributes ()) {
@@ -473,15 +576,18 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
473576 continue ;
474577 }
475578
579+ std::string type = " _ods_ir." + getPythonAttrName (namedAttr.attr );
476580 if (namedAttr.attr .isOptional ()) {
477581 os << formatv (optionalAttributeGetterTemplate, sanitizedName,
478- namedAttr.name );
582+ namedAttr.name , type );
479583 os << formatv (optionalAttributeSetterTemplate, sanitizedName,
480- namedAttr.name );
584+ namedAttr.name , type );
481585 os << formatv (attributeDeleterTemplate, sanitizedName, namedAttr.name );
482586 } else {
483- os << formatv (attributeGetterTemplate, sanitizedName, namedAttr.name );
484- os << formatv (attributeSetterTemplate, sanitizedName, namedAttr.name );
587+ os << formatv (attributeGetterTemplate, sanitizedName, namedAttr.name ,
588+ type);
589+ os << formatv (attributeSetterTemplate, sanitizedName, namedAttr.name ,
590+ type);
485591 // Non-optional attributes cannot be deleted.
486592 }
487593 }
@@ -983,8 +1089,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
9831089 assert ((!region.isVariadic () || en.index () == op.getNumRegions () - 1 ) &&
9841090 " expected only the last region to be variadic" );
9851091 os << formatv (regionAccessorTemplate, sanitizeName (region.name ),
986- std::to_string (en.index ()) +
987- (region.isVariadic () ? " :" : " " ));
1092+ std::to_string (en.index ()) + (region.isVariadic () ? " :" : " " ),
1093+ region.isVariadic () ? " _ods_ir.RegionSequence"
1094+ : " _ods_ir.Region" );
9881095 }
9891096}
9901097
0 commit comments