- 
                Notifications
    You must be signed in to change notification settings 
- Fork 14.9k
[MLIR][Complex] Add complex ops support in OPDSL. #162665
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| @llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Hugo Trachino (nujaa)
Changes
This patch allows OpDSL to generate the complex versions of existing OpDSL ops: `exp`, `log`, and `abs`.
 Adds support for one new op: `conj`.
 I needed to refactor `elemwise_unary_poly` -> `elemwise_unary_poly_cast`. This patch also includes nit-picking renaming of FileCheck names for better consistency. Full diff: https://github.com/llvm/llvm-project/pull/162665.diff 3 Files Affected: 
 diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
     square = UnaryFnType("square")
     tanh = UnaryFnType("tanh")
     erf = UnaryFnType("erf")
+    conj = UnaryFnType("conj")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     def _unary_exp(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.ExpOp(x).result
+        if _is_complex_type(x.type):
+            return complex.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.LogOp(x).result
+        if _is_complex_type(x.type):
+            return complex.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.AbsFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.AbsOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
+    def _unary_conj(self, x: Value) -> Value:
+        if _is_complex_type(x.type):
+            return complex.ConjOp(x).result
+        raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
         if _is_floating_point_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
 
 
 @linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
 ):
     O[None] = fun(cast(U, I[None]))
 
+@linalg_structured_op
+def elemwise_unary_poly(
+    I=TensorDef(T),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+    O[None] = fun(I[None])
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
         def test_f32_elemwise_exp(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+        # CHECK-LABEL: @test_c32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
         # CHECK-LABEL: @test_f32_elemwise_log
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
         # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
         def test_f32_elemwise_log(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+        # CHECK-LABEL: @test_c32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
         # CHECK-LABEL: @test_f32_elemwise_abs
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[ABS:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
         def test_f32_elemwise_abs(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
 
+        # CHECK-LABEL: @test_c32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_c32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
         # CHECK-LABEL: @test_f32_elemwise_ceil
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[CEIL]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_floor
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[FLOOR]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[NEG:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
 
         # CHECK-LABEL: @test_c32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT:   %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : complex<f32>
         # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
         def test_c32_elemwise_neg(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+        # CHECK-LABEL: @test_c32_elemwise_conj
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[CONJ]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_conj(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name
         @func.FuncOp.from_py_func(
 | 
| ✅ With the latest revision this PR passed the Python code formatter. | 
| We're trying to deprecate OpDSL, so adding more ops here is probably ill-advised. Please add support for complex types in `linalg.elementwise` instead. | 
| 
 It's partly our fault that we haven't deprecated OpDSL yet. This is probably not the right place to discuss it, but yes, please add this support to `linalg.elementwise` if possible. | 
This patch allows OpDSL to generate the complex versions of existing OpDSL ops: `exp`, `log`, and `abs`.
Adds support for one new op: `conj`.
I needed to refactor `elemwise_unary_poly` -> `elemwise_unary_poly_cast` (and add a new `elemwise_unary_poly` without the cast), since `complex.AbsOp` has inconsistent input and output types (complex vs. float). Additionally, it turns out the cast in `elemwise_unary_poly` was not necessary for the tested use cases. Let me know if you would prefer to see `elemwise_unary_poly_cast` removed completely, or whether it may be used downstream.
This patch includes nit-picking renaming of FileCheck names for better consistency.