
Conversation

nujaa
Contributor

@nujaa nujaa commented Oct 9, 2025

This patch allows OpDSL to generate the complex versions of existing OpDSL ops:

  • ExpOp
  • LogOp
  • AbsOp

It also adds support for one new op:

  • ConjOp

I needed to refactor elemwise_unary_poly into elemwise_unary_poly_cast because the complex AbsOp has mismatched input and output types (complex vs. float). Additionally, it turns out the cast in elemwise_unary_poly was not necessary for the tested use cases. Let me know if you would prefer elemwise_unary_poly_cast to be removed entirely, or whether it may still be used downstream.

This patch also includes some nit-picky renaming of FileCheck capture names for better consistency.
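
For reviewers, here is a minimal usage sketch of the new conj path, adapted from the test added in emit_misc.py. The context/module setup and the final print are ordinary MLIR Python-bindings boilerplate assumed here for self-containment; they are not part of the patch itself.

```python
from mlir.ir import (
    Context,
    Location,
    Module,
    InsertionPoint,
    RankedTensorType,
    F32Type,
    ComplexType,
)
from mlir.dialects import func, linalg
from mlir.dialects.linalg.opdsl.lang import *


# Same definition as the (refactored) op in emit_misc.py: no cast in the body,
# so the unary function sees the operand's native element type.
@linalg_structured_op
def elemwise_unary_poly(
    I=TensorDef(T),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
):
    O[None] = fun(I[None])


with Context(), Location.unknown():
    module = Module.create()
    c32 = ComplexType.get(F32Type.get())
    with InsertionPoint(module.body):

        @func.FuncOp.from_py_func(
            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
        )
        def elemwise_conj(input, init_result):
            # With this patch, UnaryFn.conj on a complex operand emits a
            # linalg.generic whose body contains complex.conj.
            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj)

    print(module)
```

Emitting this should yield a linalg.generic with complex.conj in its body, matching the FileCheck expectations in the test diff below.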

@llvmbot
Member

llvmbot commented Oct 9, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/162665.diff

3 Files Affected:

  • (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+1)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+11)
  • (modified) mlir/test/python/dialects/linalg/opdsl/emit_misc.py (+62-11)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
     square = UnaryFnType("square")
     tanh = UnaryFnType("tanh")
     erf = UnaryFnType("erf")
+    conj = UnaryFnType("conj")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     def _unary_exp(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.ExpOp(x).result
+        if _is_complex_type(x.type):
+            return complex.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.LogOp(x).result
+        if _is_complex_type(x.type):
+            return complex.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.AbsFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.AbsOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
+    def _unary_conj(self, x: Value) -> Value:
+        if _is_complex_type(x.type):
+            return complex.ConjOp(x).result
+        raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
         if _is_floating_point_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
 
 
 @linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
 ):
     O[None] = fun(cast(U, I[None]))
 
+@linalg_structured_op
+def elemwise_unary_poly(
+    I=TensorDef(T),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+    O[None] = fun(I[None])
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
         def test_f32_elemwise_exp(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+        # CHECK-LABEL: @test_c32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
         # CHECK-LABEL: @test_f32_elemwise_log
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
         # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
         def test_f32_elemwise_log(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+        # CHECK-LABEL: @test_c32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
         # CHECK-LABEL: @test_f32_elemwise_abs
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[ABS:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
         def test_f32_elemwise_abs(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
 
+        # CHECK-LABEL: @test_c32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_c32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
         # CHECK-LABEL: @test_f32_elemwise_ceil
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[CEIL]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_floor
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[FLOOR]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[NEG:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
 
         # CHECK-LABEL: @test_c32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT:   %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : complex<f32>
         # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
         def test_c32_elemwise_neg(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+        # CHECK-LABEL: @test_c32_elemwise_conj
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[CONJ]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_conj(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name
         @func.FuncOp.from_py_func(

@github-actions

github-actions bot commented Oct 9, 2025

✅ With the latest revision this PR passed the Python code formatter.

@rengolin
Member

rengolin commented Oct 9, 2025

We're trying to deprecate OpDSL, so adding more ops here is probably ill-advised.

Please add support for complex types in linalg.elementwise instead.

FYI: @ftynse @javedabsar1 @nicolasvasilache

@javedabsar1
Contributor

We're trying to deprecate OpDSL, so adding more ops here is probably ill-advised.

Please add support for complex types in linalg.elementwise instead.

FYI: @ftynse @javedabsar1 @nicolasvasilache

It's partly our fault that we haven't deprecated OpDSL yet. This is probably not the right place to discuss it, but yes, please add this support to linalg.elementwise instead if possible.
