Skip to content

Conversation

@TIFitis
Copy link
Member

@TIFitis TIFitis commented Jul 29, 2025

This patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass.

@TIFitis TIFitis requested review from Copilot, jsjodin and krzysz00 July 29, 2025 15:01
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This patch extends the ComplexToROCDLLibraryCalls pass to support conversion of additional complex number operations to their corresponding ROCDL (ROCm) library calls. The purpose is to enable GPU acceleration for a broader set of complex mathematical operations on AMD GPUs.

  • Adds conversion patterns for 9 new complex operations: AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp, and TanhOp
  • Maps each operation to appropriate OCML (Open Compute Math Library) function calls for both f32 and f64 precision
  • Includes comprehensive test coverage for all newly supported operations

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
ComplexToROCDLLibraryCalls.cpp Adds conversion patterns for 9 new complex ops and updates the conversion target to mark them as illegal
complex-to-rocdl-library-calls.mlir Adds test cases for all newly supported complex operations with both f32 and f64 variants

@llvmbot llvmbot added the mlir label Jul 29, 2025
…Calls

This patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass.
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2025

@llvm/pr-subscribers-mlir

Author: Akash Banerjee (TIFitis)

Changes

This patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass.


Full diff: https://github.com/llvm/llvm-project/pull/151166.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp (+40-1)
  • (modified) mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir (+108)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc2965e6fd..35ad99c7791db 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
       patterns.getContext(), "__ocml_cabs_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
       patterns.getContext(), "__ocml_cabs_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+      patterns.getContext(), "__ocml_carg_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+      patterns.getContext(), "__ocml_carg_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+      patterns.getContext(), "__ocml_conj_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+      patterns.getContext(), "__ocml_conj_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ccos_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ccos_f64");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
       patterns.getContext(), "__ocml_cexp_f32");
   patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
       patterns.getContext(), "__ocml_cexp_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+      patterns.getContext(), "__ocml_clog_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+      patterns.getContext(), "__ocml_clog_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+      patterns.getContext(), "__ocml_cpow_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+      patterns.getContext(), "__ocml_cpow_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+      patterns.getContext(), "__ocml_csin_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+      patterns.getContext(), "__ocml_csin_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+      patterns.getContext(), "__ocml_csqrt_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+      patterns.getContext(), "__ocml_csqrt_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ctan_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ctan_f64");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+      patterns.getContext(), "__ocml_ctanh_f32");
+  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+      patterns.getContext(), "__ocml_ctanh_f64");
 }
 
 namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
 
   ConversionTarget target(getContext());
   target.addLegalDialect<func::FuncDialect>();
-  target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+  target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+                      complex::CosOp, complex::ExpOp, complex::LogOp,
+                      complex::PowOp, complex::SinOp, complex::SqrtOp,
+                      complex::TanOp, complex::TanhOp>();
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c5986ef9e..ae59f28b46392 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
 
 // CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
 // CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
 // CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
 // CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
 
 //CHECK-LABEL: @abs_caller
 func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
   return %rf, %rd : f32, f64
 }
 
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+  // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
+  %af = complex.angle %f : complex<f32>
+  // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
+  %ad = complex.angle %d : complex<f64>
+  // CHECK: return %[[AF]], %[[AD]]
+  return %af, %ad : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
+  %cf = complex.cos %f : complex<f32>
+  // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
+  %cd = complex.cos %d : complex<f64>
+  // CHECK: return %[[CF]], %[[CD]]
+  return %cf, %cd : complex<f32>, complex<f64>
+}
+
 //CHECK-LABEL: @exp_caller
 func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
   // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
   // CHECK: return %[[EF]], %[[ED]]
   return %ef, %ed : complex<f32>, complex<f64>
 }
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
+  %lf = complex.log %f : complex<f32>
+  // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
+  %ld = complex.log %d : complex<f64>
+  // CHECK: return %[[LF]], %[[LD]]
+  return %lf, %ld : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
+  %cf2 = complex.conj %f : complex<f32>
+  // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
+  %cd2 = complex.conj %d : complex<f64>
+  // CHECK: return %[[CF]], %[[CD]]
+  return %cf2, %cd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+  %pf = complex.pow %f, %f : complex<f32>
+  // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+  %pd = complex.pow %d, %d : complex<f64>
+  // CHECK: return %[[PF]], %[[PD]]
+  return %pf, %pd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
+  %sf2 = complex.sin %f : complex<f32>
+  // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
+  %sd2 = complex.sin %d : complex<f64>
+  // CHECK: return %[[SF]], %[[SD]]
+  return %sf2, %sd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+  %sf = complex.sqrt %f : complex<f32>
+  // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+  %sd = complex.sqrt %d : complex<f64>
+  // CHECK: return %[[SF]], %[[SD]]
+  return %sf, %sd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
+  %tf2 = complex.tan %f : complex<f32>
+  // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
+  %td2 = complex.tan %d : complex<f64>
+  // CHECK: return %[[TF]], %[[TD]]
+  return %tf2, %td2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+  // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+  %tf = complex.tanh %f : complex<f32>
+  // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+  %td = complex.tanh %d : complex<f64>
+  // CHECK: return %[[TF]], %[[TD]]
+  return %tf, %td : complex<f32>, complex<f64>
+}

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved, thank you.

@TIFitis TIFitis merged commit 0a4c652 into llvm:main Jul 29, 2025
9 checks passed
@TIFitis TIFitis deleted the complexrocdl branch August 21, 2025 14:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants