Conversation

@makslevental
Contributor

No description provided.

@makslevental makslevental marked this pull request as ready for review January 27, 2025 01:02
@llvmbot llvmbot added the mlir:linalg, mlir:python (MLIR Python bindings), and mlir labels Jan 27, 2025
@llvmbot
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/124496.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/linalg/__init__.py (+44)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+60)
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..946094e2e9f691 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
 #   DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
 from .._linalg_ops_gen import *
 from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
 
 # These are the ground truth functions defined as:
 # ```
@@ -58,6 +59,7 @@
 
 from ...ir import *
 from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
 
 
 def transpose(
@@ -102,3 +104,45 @@ def broadcast(
     )
     fill_builtin_region(op.operation)
     return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+    return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+class GenericOp(GenericOp):
+    def __init__(
+        self,
+        inputs,
+        outputs,
+        indexing_maps,
+        iterator_types,
+        *,
+        doc=None,
+        library_call=None,
+        loc=None,
+        ip=None,
+    ):
+        result_types = []
+        if isinstance(outputs[0].type, RankedTensorType):
+            result_types = [o.type for o in outputs]
+
+        super().__init__(
+            result_types,
+            inputs,
+            outputs,
+            indexing_maps,
+            iterator_types,
+            doc=doc,
+            library_call=library_call,
+            loc=loc,
+            ip=ip,
+        )
+        element_types = [i.type.element_type for i in inputs] + [
+            o.type.element_type for o in outputs
+        ]
+        self.regions[0].blocks.append(*element_types)
+
+
+generic = region_op(GenericOp, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..b7e0f2884bb249 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
 
     print(module)
 
+
 # CHECK-LABEL: TEST: testIdentityRegionOps
 @run
 def testIdentityRegionOps():
@@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3):
                 op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            id_map = AffineMap.get_identity(2)
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+            x = tensor.empty((16, 16), f32)
+            y = tensor.empty((16, 16), f32)
+
+            # CHECK: %[[VAL_3:*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32):
+            # CHECK:   linalg.yield %in : f32
+            # CHECK: } -> tensor<16x16xf32>
+            @linalg.generic(
+                [x],
+                [y],
+                [id_map, id_map],
+                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+            )
+            def f(x, y):
+                return x
+
+            assert isinstance(f, Value)
+
+            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+            z = tensor.empty((16, 16, 16), f32)
+
+            minor_id = AffineMap.get_minor_identity(3, 2)
+            id_map = AffineMap.get_identity(3)
+
+            # CHECK: %%[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+            # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
+            # CHECK:   linalg.yield %in, %out : f32, f32
+            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+            @linalg.generic(
+                [x],
+                [z, z],
+                [minor_id, id_map, id_map],
+                [
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                    linalg.IteratorType.parallel,
+                ],
+            )
+            def g(x, z1, z2):
+                return x, z1
+
+            assert isinstance(g, OpResultList)
+            assert len(g) == 2
+            assert isinstance(g[0].type, RankedTensorType)
+            assert isinstance(g[1].type, RankedTensorType)
+
+    print(module)
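For readers skimming the diff: the new GenericOp override infers tensor result types from the outs operands, appends an entry block whose arguments carry the operands' element types, and region_op(GenericOp, terminator=YieldOp) wraps it in a decorator whose return values become the linalg.yield operands. Below is a minimal sketch, not part of this PR, of driving the decorator with a body that calls into another dialect; the function name elementwise_add and the use of arith.AddFOp are illustrative assumptions, while the linalg.generic signature follows the tests above.

from mlir.ir import AffineMap, Context, F32Type, InsertionPoint, Location, Module
from mlir.dialects import arith, linalg, tensor

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        id_map = AffineMap.get_identity(2)
        lhs = tensor.empty((8, 8), f32)
        rhs = tensor.empty((8, 8), f32)
        acc = tensor.empty((8, 8), f32)

        # One indexing map per operand (two inputs plus one output); the body
        # receives one block argument per operand element type, and whatever it
        # returns is yielded through linalg.yield.
        @linalg.generic(
            [lhs, rhs],
            [acc],
            [id_map, id_map, id_map],
            [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
        )
        def elementwise_add(a, b, c):
            return arith.AddFOp(a, b).result

        # With a single tensor output the decorator evaluates to the op's
        # result Value (here of type tensor<8x8xf32>), matching the asserts in
        # the test above.

    print(module)

Because the terminator is supplied as YieldOp, the decorated function's return values drive the linalg.yield directly, so user code never constructs the terminator by hand.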

@llvmbot
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir

@makslevental makslevental force-pushed the users/makslevental/linalg-generic-op branch 3 times, most recently from c61e7c9 to a4f728f on January 27, 2025 02:32
@makslevental makslevental requested a review from jpienaar January 27, 2025 02:33
@makslevental makslevental force-pushed the users/makslevental/linalg-generic-op branch from a4f728f to 3053439 on January 27, 2025 02:37
Member

@Groverkss Groverkss left a comment

LGTM, but I'll let someone else approve since I requested this feature...

@makslevental makslevental merged commit 1bc5fe6 into llvm:main Jan 28, 2025
8 checks passed
@makslevental makslevental deleted the users/makslevental/linalg-generic-op branch January 28, 2025 17:02