diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 8fb1227ee80ff..742262a9c4969 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -10,6 +10,7 @@ # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * from .._linalg_enum_gen import * +from .._linalg_enum_gen import _iteratortypeenum # These are the ground truth functions defined as: # ``` @@ -58,6 +59,7 @@ from ...ir import * from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from ...extras.meta import region_op def transpose( @@ -102,3 +104,46 @@ def broadcast( ) fill_builtin_region(op.operation) return op + + +@register_attribute_builder("IteratorTypeArrayAttr") +def _IteratorTypeArrayAttr(x, context): + return ArrayAttr.get([_iteratortypeenum(v, context) for v in x]) + + +# The underscore is needed here so that there's no collision with opdsl generation. +class GenericOp_(GenericOp): + def __init__( + self, + inputs, + outputs, + indexing_maps, + iterator_types, + *, + doc=None, + library_call=None, + loc=None, + ip=None, + ): + result_types = [] + if isinstance(outputs[0].type, RankedTensorType): + result_types = [o.type for o in outputs] + + super().__init__( + result_types, + inputs, + outputs, + indexing_maps, + iterator_types, + doc=doc, + library_call=library_call, + loc=loc, + ip=ip, + ) + element_types = [i.type.element_type for i in inputs] + [ + o.type.element_type for o in outputs + ] + self.regions[0].blocks.append(*element_types) + + +generic = region_op(GenericOp_, terminator=YieldOp) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 72045a07b2da8..ac7186c24bed8 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -1,6 +1,6 @@ # RUN: %PYTHON %s | FileCheck %s -from mlir.dialects import arith, builtin, func, linalg, tensor +from mlir.dialects import arith, func, linalg, tensor, memref from mlir.dialects.linalg.opdsl.lang import * from mlir.ir import * @@ -84,6 +84,7 @@ def named_form(lhs, rhs): print(module) + # CHECK-LABEL: TEST: testIdentityRegionOps @run def testIdentityRegionOps(): @@ -161,3 +162,97 @@ def broadcast_op(op1, op2, op3): op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0]) print(module) + + +# CHECK-LABEL: TEST: testGenericOp +@run +def testGenericOp(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + memref_t = MemRefType.get([10, 10], f32) + with InsertionPoint(module.body): + id_map_1 = AffineMap.get_identity(2) + # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32> + # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32> + x = tensor.empty((16, 16), f32) + y = tensor.empty((16, 16), f32) + + # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) { + # CHECK: ^bb0(%in: f32, %out: f32): + # CHECK: linalg.yield %in : f32 + # CHECK: } -> tensor<16x16xf32> + @linalg.generic( + [x], + [y], + [id_map_1, id_map_1], + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], + ) + def f(a, b): + assert isinstance(a, Value) + assert isinstance(a.type, F32Type) + assert isinstance(b, Value) + assert isinstance(b.type, F32Type) + return a + + assert isinstance(f, Value) + assert isinstance(f.type, RankedTensorType) + + # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32> + z = tensor.empty((16, 16, 16), f32) + + minor_id = AffineMap.get_minor_identity(3, 2) + id_map_2 = AffineMap.get_identity(3) + + # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) { + # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32): + # CHECK: linalg.yield %in, %out : f32, f32 + # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>) + @linalg.generic( + [x], + [z, z], + [minor_id, id_map_2, id_map_2], + [ + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + linalg.IteratorType.parallel, + ], + ) + def g(a, b, c): + assert isinstance(a, Value) + assert isinstance(a.type, F32Type) + assert isinstance(b, Value) + assert isinstance(b.type, F32Type) + assert isinstance(c, Value) + assert isinstance(c.type, F32Type) + return a, b + + assert isinstance(g, OpResultList) + assert len(g) == 2 + assert isinstance(g[0].type, RankedTensorType) + assert isinstance(g[1].type, RankedTensorType) + + # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32> + # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32> + xx = memref.alloc(memref_t, [], []) + yy = memref.alloc(memref_t, [], []) + + # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) { + # CHECK: ^bb0(%in: f32, %out: f32): + # CHECK: linalg.yield %in : f32 + # CHECK: } + @linalg.generic( + [xx], + [yy], + [id_map_1, id_map_1], + [linalg.IteratorType.parallel, linalg.IteratorType.parallel], + ) + def f(a, b): + assert isinstance(a, Value) + assert isinstance(a.type, F32Type) + assert isinstance(b, Value) + assert isinstance(b.type, F32Type) + return a + + module.operation.verify() + print(module)