-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR] Add conversion support for more ops from ComplexToROCDLLibraryCalls #151166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This patch extends the ComplexToROCDLLibraryCalls pass to support conversion of additional complex number operations to their corresponding ROCDL (ROCm) library calls. The purpose is to enable GPU acceleration for a broader set of complex mathematical operations on AMD GPUs.
- Adds conversion patterns for 9 new complex operations: AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp, and TanhOp
- Maps each operation to appropriate OCML (Open Compute Math Library) function calls for both f32 and f64 precision
- Includes comprehensive test coverage for all newly supported operations
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| ComplexToROCDLLibraryCalls.cpp | Adds conversion patterns for 9 new complex ops and updates the conversion target to mark them as illegal |
| complex-to-rocdl-library-calls.mlir | Adds test cases for all newly supported complex operations with both f32 and f64 variants |
…Calls This patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass.
|
@llvm/pr-subscribers-mlir Author: Akash Banerjee (TIFitis) ChangesThis patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass. Full diff: https://github.com/llvm/llvm-project/pull/151166.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc2965e6fd..35ad99c7791db 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index bae7c5986ef9e..ae59f28b46392 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -2,8 +2,26 @@
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
+// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
+// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
+// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
+// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
//CHECK-LABEL: @abs_caller
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
return %rf, %rd : f32, f64
}
+//CHECK-LABEL: @angle_caller
+func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
+ // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
+ %af = complex.angle %f : complex<f32>
+ // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
+ %ad = complex.angle %d : complex<f64>
+ // CHECK: return %[[AF]], %[[AD]]
+ return %af, %ad : f32, f64
+}
+
+//CHECK-LABEL: @cos_caller
+func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
+ %cf = complex.cos %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
+ %cd = complex.cos %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf, %cd : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @exp_caller
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
// CHECK: return %[[EF]], %[[ED]]
return %ef, %ed : complex<f32>, complex<f64>
}
+
+//CHECK-LABEL: @log_caller
+func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
+ %lf = complex.log %f : complex<f32>
+ // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
+ %ld = complex.log %d : complex<f64>
+ // CHECK: return %[[LF]], %[[LD]]
+ return %lf, %ld : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @conj_caller
+func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
+ %cf2 = complex.conj %f : complex<f32>
+ // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
+ %cd2 = complex.conj %d : complex<f64>
+ // CHECK: return %[[CF]], %[[CD]]
+ return %cf2, %cd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
+ %pf = complex.pow %f, %f : complex<f32>
+ // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
+ %pd = complex.pow %d, %d : complex<f64>
+ // CHECK: return %[[PF]], %[[PD]]
+ return %pf, %pd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sin_caller
+func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
+ %sf2 = complex.sin %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
+ %sd2 = complex.sin %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf2, %sd2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @sqrt_caller
+func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
+ %sf = complex.sqrt %f : complex<f32>
+ // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
+ %sd = complex.sqrt %d : complex<f64>
+ // CHECK: return %[[SF]], %[[SD]]
+ return %sf, %sd : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tan_caller
+func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
+ %tf2 = complex.tan %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
+ %td2 = complex.tan %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf2, %td2 : complex<f32>, complex<f64>
+}
+
+//CHECK-LABEL: @tanh_caller
+func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
+ // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
+ %tf = complex.tanh %f : complex<f32>
+ // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
+ %td = complex.tanh %d : complex<f64>
+ // CHECK: return %[[TF]], %[[TD]]
+ return %tf, %td : complex<f32>, complex<f64>
+}
|
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved, thank you.
This patch adds conversion support for AngleOp, ConjOp, CosOp, LogOp, PowOp, SinOp, SqrtOp, TanOp and TanhOp to the ComplexToROCDLLibraryCalls pass.