[mlir][python] implement GenericOp bindings #124496
Conversation
@llvm/pr-subscribers-mlir-linalg

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/124496.diff

2 Files Affected:
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 8fb1227ee80ff5..946094e2e9f691 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -10,6 +10,7 @@
# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py.
from .._linalg_ops_gen import *
from .._linalg_enum_gen import *
+from .._linalg_enum_gen import _iteratortypeenum
# These are the ground truth functions defined as:
# ```
@@ -58,6 +59,7 @@
from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+from ...extras.meta import region_op
def transpose(
@@ -102,3 +104,45 @@ def broadcast(
)
fill_builtin_region(op.operation)
return op
+
+
+@register_attribute_builder("IteratorTypeArrayAttr")
+def _IteratorTypeArrayAttr(x, context):
+ return ArrayAttr.get([_iteratortypeenum(v, context) for v in x])
+
+
+class GenericOp(GenericOp):
+ def __init__(
+ self,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ *,
+ doc=None,
+ library_call=None,
+ loc=None,
+ ip=None,
+ ):
+ result_types = []
+ if isinstance(outputs[0].type, RankedTensorType):
+ result_types = [o.type for o in outputs]
+
+ super().__init__(
+ result_types,
+ inputs,
+ outputs,
+ indexing_maps,
+ iterator_types,
+ doc=doc,
+ library_call=library_call,
+ loc=loc,
+ ip=ip,
+ )
+ element_types = [i.type.element_type for i in inputs] + [
+ o.type.element_type for o in outputs
+ ]
+ self.regions[0].blocks.append(*element_types)
+
+
+generic = region_op(GenericOp, terminator=YieldOp)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 72045a07b2da80..b7e0f2884bb249 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
print(module)
+
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
@@ -161,3 +162,62 @@ def broadcast_op(op1, op2, op3):
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
print(module)
+
+
+# CHECK-LABEL: TEST: testGenericOp
+@run
+def testGenericOp():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ id_map = AffineMap.get_identity(2)
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
+ x = tensor.empty((16, 16), f32)
+ y = tensor.empty((16, 16), f32)
+
+    # CHECK: %[[VAL_3:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32):
+ # CHECK: linalg.yield %in : f32
+ # CHECK: } -> tensor<16x16xf32>
+ @linalg.generic(
+ [x],
+ [y],
+ [id_map, id_map],
+ [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
+ )
+ def f(x, y):
+ return x
+
+ assert isinstance(f, Value)
+
+ # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
+ z = tensor.empty((16, 16, 16), f32)
+
+ minor_id = AffineMap.get_minor_identity(3, 2)
+ id_map = AffineMap.get_identity(3)
+
+    # CHECK: %[[VAL_4:.*]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
+ # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
+ # CHECK: linalg.yield %in, %out : f32, f32
+ # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
+ @linalg.generic(
+ [x],
+ [z, z],
+ [minor_id, id_map, id_map],
+ [
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ linalg.IteratorType.parallel,
+ ],
+ )
+ def g(x, z1, z2):
+ return x, z1
+
+ assert isinstance(g, OpResultList)
+ assert len(g) == 2
+ assert isinstance(g[0].type, RankedTensorType)
+ assert isinstance(g[1].type, RankedTensorType)
+
+ print(module)
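
For quick reference, here is a minimal standalone sketch of how the new decorator form is intended to be used; it mirrors testGenericOp above. The module setup, tensor shapes, and the copy function name are illustrative only, and it assumes an MLIR Python build with these bindings enabled.

```python
from mlir.ir import AffineMap, Context, F32Type, InsertionPoint, Location, Module, Value
from mlir.dialects import linalg, tensor

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        id_map = AffineMap.get_identity(2)
        # Hypothetical shapes; any matching ranked tensor operands would do.
        x = tensor.empty((16, 16), f32)
        y = tensor.empty((16, 16), f32)

        # The decorated function becomes the body of the linalg.generic region:
        # its block arguments carry the element types of the ins/outs operands,
        # and the returned value is yielded via linalg.YieldOp.
        @linalg.generic(
            [x],
            [y],
            [id_map, id_map],
            [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
        )
        def copy(a, b):
            return a

        # With tensor outs, the decorator evaluates to the op's result Value.
        assert isinstance(copy, Value)

    print(module)
```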
@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/124496.diff (same diff as posted above)
Force-pushed c61e7c9 to a4f728f
Force-pushed a4f728f to 3053439
Groverkss left a comment
LGTM, but I'll let someone else approve since I requested this feature...
No description provided.