Commit 175e3be
[MLIR][Python] Add region_op wrappers for linalg (#167616)
Makes linalg.reduce and linalg.map region_ops so they can be constructed from functions and be called as decorators.
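
As a minimal sketch of the resulting API (assuming an MLIR build with the Python bindings enabled; the verified end-to-end versions live in the test diff below), the decorated function becomes the body of the op's region, and its return value is forwarded to the auto-appended linalg.YieldOp terminator. The function name itself (product/reduced here are hypothetical names for illustration) then stands for the op's result and can be used as an ordinary SSA value:

# Sketch only; mirrors the new testReduceOp below.
from mlir import ir
from mlir.dialects import arith, builtin, func, linalg, tensor
from mlir.extras import types as T

with ir.Context(), ir.Location.unknown():
    f32 = T.f32()

    @builtin.module
    def module():
        @func.func(T.tensor(10, f32))
        def product(input):
            acc_type = ir.RankedTensorType.get((), f32)
            init = tensor.splat(acc_type, arith.constant(f32, 1.0), [])

            # The function body becomes the reduction region; the returned
            # value is wrapped in the linalg.YieldOp terminator.
            @linalg.reduce(
                result=[acc_type],
                inputs=[input],
                inits=[init],
                dimensions=ir.DenseI64ArrayAttr.get([0]),
            )
            def reduced(element: f32, acc: f32):
                return arith.mulf(acc, element)

            # `reduced` is the reduce op's result, a tensor<f32>.
            return tensor.extract(reduced, [])

    print(module)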
1 parent 389a23c commit 175e3be

File tree

mlir/python/mlir/dialects/linalg/__init__.py
mlir/test/python/dialects/linalg/ops.py

2 files changed: +79 −1 lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -352,3 +352,7 @@ def unpack(
             ip=ip,
         )
     )
+
+
+reduce = region_op(ReduceOp, terminator=YieldOp)
+map = region_op(MapOp, terminator=YieldOp)

mlir/test/python/dialects/linalg/ops.py

Lines changed: 75 additions & 1 deletion
@@ -1,7 +1,8 @@
 # RUN: %PYTHON %s | FileCheck %s
 
-from mlir.dialects import arith, func, linalg, tensor, memref
+from mlir.dialects import arith, func, linalg, tensor, memref, builtin
 from mlir.dialects.linalg.opdsl.lang import *
+from mlir.extras import types as T
 from mlir.ir import *
 
 
@@ -857,3 +858,76 @@ def elementwise_op(
         )
 
     print(module)
+
+
+@run
+def testReduceOp():
+    with Context(), Location.unknown():
+        f32 = T.f32()
+        tensor_type = T.tensor(10, f32)
+
+        @builtin.module
+        def module():
+            @func.func(tensor_type)
+            def reduce_op(input):
+                c1 = arith.constant(f32, 1.0)
+                single_result = ir.RankedTensorType.get((), f32)
+                dims = ir.DenseI64ArrayAttr.get([0])
+                init = tensor.splat(single_result, c1, [])
+
+                @linalg.reduce(
+                    result=[single_result],
+                    inputs=[input],
+                    inits=[init],
+                    dimensions=dims,
+                )
+                def reduced(element: f32, acc: f32):
+                    return arith.mulf(acc, element)
+
+                return tensor.extract(reduced, [])
+
+        print(module)
+
+
+# CHECK-LABEL: func.func @reduce_op(
+# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> f32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.000000e+00 : f32
+# CHECK: %[[SPLAT_0:.*]] = tensor.splat %[[CONSTANT_0]] : tensor<f32>
+# CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.mulf } ins(%[[ARG0]] : tensor<10xf32>) outs(%[[SPLAT_0]] : tensor<f32>) dimensions = [0]
+# CHECK: %[[EXTRACT_0:.*]] = tensor.extract %[[REDUCE_0]][] : tensor<f32>
+# CHECK: return %[[EXTRACT_0]] : f32
+# CHECK: }
+
+
+@run
+def testMapOp():
+    with Context(), Location.unknown():
+        f32 = T.f32()
+        tensor_type = T.tensor(10, f32)
+
+        @builtin.module
+        def module():
+            @func.func(tensor_type)
+            def map_op(input):
+                empty = tensor.empty(tensor_type.shape, f32)
+
+                @linalg.map(
+                    result=[tensor_type],
+                    inputs=[input, input],
+                    init=empty,
+                )
+                def add(element: f32, acc: f32, init: f32):
+                    return arith.addf(element, acc)
+
+                return add
+
+        module.verify()
+        print(module)
+
+
+# CHECK-LABEL: func.func @map_op(
+# CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32>) -> tensor<10xf32> {
+# CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<10xf32>
+# CHECK: %[[MAP_0:.*]] = linalg.map { arith.addf } ins(%[[ARG0]], %[[ARG0]] : tensor<10xf32>, tensor<10xf32>) outs(%[[EMPTY_0]] : tensor<10xf32>)
+# CHECK: return %[[MAP_0]] : tensor<10xf32>
+# CHECK: }