Skip to content

Commit ba739c1

Browse files
authored
[MLIR][Linalg][Python] Improve bindings for linalg.elementwise (llvm#139462)
Adds wrappers for ElementWiseOp, in particular to ensure appropriate default indexing maps are derived.
1 parent 688bccb commit ba739c1

File tree

2 files changed

+247
-0
lines changed

2 files changed

+247
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,67 @@ def contract(
216216
)
217217

218218

219+
# Extend and shadow the TableGen-derived version to make sure correct default
220+
# indexing_maps are derived (as there is no mechanism for doing so given the
221+
# Python API bypasses the C++-builders).
222+
class ElementwiseOp_(ElementwiseOp):
223+
def __init__(
224+
self,
225+
result_tensors,
226+
inputs,
227+
outputs,
228+
kind,
229+
*,
230+
indexing_maps=None,
231+
loc=None,
232+
ip=None,
233+
):
234+
if indexing_maps is None:
235+
inputs = [_get_op_result_or_value(in_) for in_ in inputs]
236+
for in0, in1 in zip(inputs[:-1], inputs[1:]):
237+
assert in0.type == in1.type
238+
output = _get_op_result_or_value(outputs[0])
239+
assert inputs[0].type == output.type
240+
num_args = len(inputs) + 1
241+
indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args
242+
243+
super().__init__(
244+
result_tensors=result_tensors,
245+
inputs=inputs,
246+
outputs=outputs,
247+
kind=kind,
248+
indexing_maps=indexing_maps,
249+
loc=loc,
250+
ip=ip,
251+
)
252+
253+
254+
ElementwiseOp = ElementwiseOp_
255+
256+
257+
def elementwise(
258+
*ins: Union[Operation, OpView, Value],
259+
outs: Sequence[Union[Operation, OpView, Value]],
260+
kind: Union[ElementwiseKind, Attribute],
261+
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
262+
):
263+
ins = [_get_op_result_or_value(input) for input in ins]
264+
if len(outs) != 1:
265+
raise ValueError(f"{outs=} must have length 1.")
266+
init = _get_op_result_or_value(outs[0])
267+
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
268+
269+
op = ElementwiseOp(
270+
result_tensors=result_types,
271+
inputs=ins,
272+
outputs=[init],
273+
kind=kind,
274+
indexing_maps=indexing_maps,
275+
)
276+
fill_builtin_region(op.operation)
277+
return _get_op_result_or_op_results(op)
278+
279+
219280
def pack(
220281
source,
221282
dest,

mlir/test/python/dialects/linalg/ops.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,3 +606,189 @@ def tensor_pack(src, dst):
606606
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607607
# CHECK: }
608608
print(module)
609+
610+
611+
# CHECK-LABEL: TEST: testElementwiseOp
612+
@run
613+
def testElementwiseOp():
614+
with Context(), Location.unknown():
615+
module = Module.create()
616+
f32 = F32Type.get()
617+
with InsertionPoint(module.body):
618+
rect_shape = (8, 16)
619+
vert_line_shape = (8,)
620+
hor_line_shape = (16,)
621+
transposed_rect_shape = (16, 8)
622+
623+
# CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
624+
# CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
625+
# CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
626+
# CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>
627+
628+
ident_map_2d = AffineMap.get_identity(2)
629+
transposed_map_2d = AffineMap.get_permutation((1, 0))
630+
vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
631+
hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])
632+
633+
# CHECK: func.func @elementwise_op(
634+
@func.FuncOp.from_py_func(
635+
# CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
636+
RankedTensorType.get(rect_shape, f32),
637+
# CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
638+
MemRefType.get(rect_shape, f32),
639+
# CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
640+
RankedTensorType.get(vert_line_shape, f32),
641+
# CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
642+
MemRefType.get(vert_line_shape, f32),
643+
# CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
644+
RankedTensorType.get(hor_line_shape, f32),
645+
# CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
646+
MemRefType.get(hor_line_shape, f32),
647+
# CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
648+
RankedTensorType.get(transposed_rect_shape, f32),
649+
# CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
650+
MemRefType.get(transposed_rect_shape, f32),
651+
)
652+
def elementwise_op(
653+
rect,
654+
rect_mem,
655+
vert_line,
656+
vert_line_mem,
657+
hor_line,
658+
hor_line_mem,
659+
trans_rect,
660+
trans_rect_mem,
661+
):
662+
# CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
663+
out_rect = tensor.EmptyOp(rect_shape, f32)
664+
# CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
665+
out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])
666+
667+
if _inferred_affine_maps := True:
668+
# CHECK: linalg.elementwise
669+
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
670+
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
671+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
672+
op1 = linalg.ElementwiseOp(
673+
result_tensors=(out_rect.result.type,),
674+
inputs=(rect,),
675+
outputs=(out_rect,),
676+
kind=linalg.ElementwiseKind.exp,
677+
)
678+
linalg.fill_builtin_region(op1.operation)
679+
680+
# CHECK: linalg.elementwise
681+
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
682+
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
683+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
684+
linalg.elementwise(
685+
rect,
686+
outs=(out_rect,),
687+
kind=linalg.ElementwiseKind.exp,
688+
)
689+
690+
# CHECK: linalg.elementwise
691+
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
692+
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
693+
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
694+
linalg.elementwise(
695+
rect_mem,
696+
outs=(out_rect_mem,),
697+
kind=linalg.ElementwiseKind.exp,
698+
)
699+
700+
if _explicit_ident_affine_maps := True:
701+
# Same as above but with default identity indexing_maps explicitly provided.
702+
# CHECK: linalg.elementwise
703+
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
704+
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
705+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
706+
op3 = linalg.ElementwiseOp(
707+
result_tensors=(out_rect.result.type,),
708+
inputs=(rect,),
709+
outputs=(out_rect,),
710+
kind=linalg.ElementwiseKind.exp,
711+
indexing_maps=[ident_map_2d, ident_map_2d],
712+
)
713+
linalg.fill_builtin_region(op3.operation)
714+
715+
# CHECK: linalg.elementwise
716+
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
717+
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
718+
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
719+
linalg.elementwise(
720+
rect_mem,
721+
outs=(out_rect_mem,),
722+
kind=linalg.ElementwiseKind.exp,
723+
indexing_maps=[ident_map_2d, ident_map_2d],
724+
)
725+
726+
if _ops_with_non_ident_input_maps := True:
727+
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
728+
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
729+
# CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
730+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
731+
op4 = linalg.ElementwiseOp(
732+
result_tensors=(out_rect.result.type,),
733+
inputs=(vert_line,),
734+
outputs=(out_rect,),
735+
kind=linalg.ElementwiseKind.exp,
736+
indexing_maps=[vert_line_bcast_map, ident_map_2d],
737+
)
738+
linalg.fill_builtin_region(op4.operation)
739+
740+
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
741+
# CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
742+
# CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
743+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
744+
op4 = linalg.ElementwiseOp(
745+
result_tensors=(out_rect.result.type,),
746+
inputs=(rect, vert_line),
747+
outputs=(out_rect,),
748+
kind=linalg.ElementwiseKind.add,
749+
indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
750+
)
751+
linalg.fill_builtin_region(op4.operation)
752+
753+
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
754+
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
755+
# CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
756+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
757+
linalg.elementwise(
758+
vert_line,
759+
hor_line,
760+
outs=(out_rect,),
761+
kind=linalg.ElementwiseKind.div,
762+
indexing_maps=[
763+
vert_line_bcast_map,
764+
hor_line_bcast_map,
765+
ident_map_2d,
766+
],
767+
)
768+
769+
if _ops_with_non_ident_and_transposed_input_maps := True:
770+
# CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
771+
vert_line_bools_mem = memref.alloca(
772+
MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
773+
[],
774+
[],
775+
)
776+
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
777+
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
778+
# CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
779+
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
780+
linalg.elementwise(
781+
vert_line_bools_mem,
782+
hor_line_mem,
783+
trans_rect_mem,
784+
outs=(out_rect_mem,),
785+
kind=linalg.ElementwiseKind.select,
786+
indexing_maps=[
787+
vert_line_bcast_map,
788+
hor_line_bcast_map,
789+
transposed_map_2d,
790+
ident_map_2d,
791+
],
792+
)
793+
794+
print(module)

0 commit comments

Comments
 (0)