Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit a2056b6

Browse files
authored
[mlir][python] implement GenericOp bindings (#124496)
1 parent 4fa9291 commit a2056b6

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
1111
from .._linalg_ops_gen import *
1212
from .._linalg_enum_gen import *
13+
from .._linalg_enum_gen import _iteratortypeenum
1314

1415
# These are the ground truth functions defined as:
1516
# ```
@@ -58,6 +59,7 @@
5859

5960
from ...ir import *
6061
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
62+
from ...extras.meta import region_op
6163

6264

6365
def transpose(
@@ -102,3 +104,46 @@ def broadcast(
102104
)
103105
fill_builtin_region(op.operation)
104106
return op
107+
108+
109+
@register_attribute_builder("IteratorTypeArrayAttr")
110+
def _IteratorTypeArrayAttr(x, context):
111+
return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
112+
113+
114+
# The underscore is needed here so that there's no collision with opdsl generation.
115+
class GenericOp_(GenericOp):
116+
def __init__(
117+
self,
118+
inputs,
119+
outputs,
120+
indexing_maps,
121+
iterator_types,
122+
*,
123+
doc=None,
124+
library_call=None,
125+
loc=None,
126+
ip=None,
127+
):
128+
result_types = []
129+
if isinstance(outputs[0].type, RankedTensorType):
130+
result_types = [o.type for o in outputs]
131+
132+
super().__init__(
133+
result_types,
134+
inputs,
135+
outputs,
136+
indexing_maps,
137+
iterator_types,
138+
doc=doc,
139+
library_call=library_call,
140+
loc=loc,
141+
ip=ip,
142+
)
143+
element_types = [i.type.element_type for i in inputs] + [
144+
o.type.element_type for o in outputs
145+
]
146+
self.regions[0].blocks.append(*element_types)
147+
148+
149+
generic = region_op(GenericOp_, terminator=YieldOp)

0 commit comments

Comments
 (0)