Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 346f919

Browse files
authored
[MLIR][Linalg][Python] Improve bindings for linalg.elementwise (#139462)
Adds wrappers for ElementWiseOp, in particular to ensure appropriate default indexing maps are derived.
1 parent 342aa55 commit 346f919

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,67 @@ def contract(
216216
)
217217

218218

219+
# Extend and shadow the TableGen-derived version to make sure correct default
220+
# indexing_maps are derived (as there is no mechanism for doing so given the
221+
# Python API bypasses the C++-builders).
222+
class ElementwiseOp_(ElementwiseOp):
223+
def __init__(
224+
self,
225+
result_tensors,
226+
inputs,
227+
outputs,
228+
kind,
229+
*,
230+
indexing_maps=None,
231+
loc=None,
232+
ip=None,
233+
):
234+
if indexing_maps is None:
235+
inputs = [_get_op_result_or_value(in_) for in_ in inputs]
236+
for in0, in1 in zip(inputs[:-1], inputs[1:]):
237+
assert in0.type == in1.type
238+
output = _get_op_result_or_value(outputs[0])
239+
assert inputs[0].type == output.type
240+
num_args = len(inputs) + 1
241+
indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args
242+
243+
super().__init__(
244+
result_tensors=result_tensors,
245+
inputs=inputs,
246+
outputs=outputs,
247+
kind=kind,
248+
indexing_maps=indexing_maps,
249+
loc=loc,
250+
ip=ip,
251+
)
252+
253+
254+
ElementwiseOp = ElementwiseOp_
255+
256+
257+
def elementwise(
258+
*ins: Union[Operation, OpView, Value],
259+
outs: Sequence[Union[Operation, OpView, Value]],
260+
kind: Union[ElementwiseKind, Attribute],
261+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
262+
):
263+
ins = [_get_op_result_or_value(input) for input in ins]
264+
if len(outs) != 1:
265+
raise ValueError(f"{outs=} must have length 1.")
266+
init = _get_op_result_or_value(outs[0])
267+
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
268+
269+
op = ElementwiseOp(
270+
result_tensors=result_types,
271+
inputs=ins,
272+
outputs=[init],
273+
kind=kind,
274+
indexing_maps=indexing_maps,
275+
)
276+
fill_builtin_region(op.operation)
277+
return _get_op_result_or_op_results(op)
278+
279+
219280
def pack(
220281
source,
221282
dest,

0 commit comments

Comments
 (0)