@@ -216,6 +216,67 @@ def contract(
216216 )
217217
218218
219+ # Extend and shadow the TableGen-derived version to make sure correct default
220+ # indexing_maps are derived (as there is no mechanism for doing so given the
221+ # Python API bypasses the C++-builders).
222+ class ElementwiseOp_ (ElementwiseOp ):
223+ def __init__ (
224+ self ,
225+ result_tensors ,
226+ inputs ,
227+ outputs ,
228+ kind ,
229+ * ,
230+ indexing_maps = None ,
231+ loc = None ,
232+ ip = None ,
233+ ):
234+ if indexing_maps is None :
235+ inputs = [_get_op_result_or_value (in_ ) for in_ in inputs ]
236+ for in0 , in1 in zip (inputs [:- 1 ], inputs [1 :]):
237+ assert in0 .type == in1 .type
238+ output = _get_op_result_or_value (outputs [0 ])
239+ assert inputs [0 ].type == output .type
240+ num_args = len (inputs ) + 1
241+ indexing_maps = [AffineMap .get_identity (output .type .rank )] * num_args
242+
243+ super ().__init__ (
244+ result_tensors = result_tensors ,
245+ inputs = inputs ,
246+ outputs = outputs ,
247+ kind = kind ,
248+ indexing_maps = indexing_maps ,
249+ loc = loc ,
250+ ip = ip ,
251+ )
252+
253+
254+ ElementwiseOp = ElementwiseOp_
255+
256+
257+ def elementwise (
258+ * ins : Union [Operation , OpView , Value ],
259+ outs : Sequence [Union [Operation , OpView , Value ]],
260+ kind : Union [ElementwiseKind , Attribute ],
261+ indexing_maps : Optional [Sequence [AffineMapAttr ]] = None ,
262+ ):
263+ ins = [_get_op_result_or_value (input ) for input in ins ]
264+ if len (outs ) != 1 :
265+ raise ValueError (f"{ outs = } must have length 1." )
266+ init = _get_op_result_or_value (outs [0 ])
267+ result_types = [init .type ] if isinstance (init .type , RankedTensorType ) else []
268+
269+ op = ElementwiseOp (
270+ result_tensors = result_types ,
271+ inputs = ins ,
272+ outputs = [init ],
273+ kind = kind ,
274+ indexing_maps = indexing_maps ,
275+ )
276+ fill_builtin_region (op .operation )
277+ return _get_op_result_or_op_results (op )
278+
279+
219280def pack (
220281 source ,
221282 dest ,
0 commit comments