diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc2965e6fd..35ad99c7791db 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index bae7c5986ef9e..ae59f28b46392 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -2,8 +2,26 @@ // CHECK-DAG: @__ocml_cabs_f32(complex) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex) -> f64 +// CHECK-DAG: @__ocml_carg_f32(complex) -> f32 +// CHECK-DAG: @__ocml_carg_f64(complex) -> f64 +// CHECK-DAG: @__ocml_ccos_f32(complex) -> complex +// CHECK-DAG: @__ocml_ccos_f64(complex) -> complex // CHECK-DAG: @__ocml_cexp_f32(complex) -> complex // CHECK-DAG: @__ocml_cexp_f64(complex) -> complex +// CHECK-DAG: @__ocml_clog_f32(complex) -> complex +// CHECK-DAG: @__ocml_clog_f64(complex) -> complex +// CHECK-DAG: @__ocml_conj_f32(complex) -> complex +// CHECK-DAG: @__ocml_conj_f64(complex) -> complex +// CHECK-DAG: @__ocml_cpow_f32(complex, complex) -> complex +// CHECK-DAG: @__ocml_cpow_f64(complex, complex) -> complex +// CHECK-DAG: @__ocml_csin_f32(complex) -> complex +// CHECK-DAG: @__ocml_csin_f64(complex) -> complex +// CHECK-DAG: @__ocml_csqrt_f32(complex) -> complex +// CHECK-DAG: @__ocml_csqrt_f64(complex) -> complex +// CHECK-DAG: @__ocml_ctan_f32(complex) -> complex +// CHECK-DAG: @__ocml_ctan_f64(complex) -> complex +// CHECK-DAG: @__ocml_ctanh_f32(complex) -> complex +// CHECK-DAG: @__ocml_ctanh_f64(complex) -> complex //CHECK-LABEL: @abs_caller func.func @abs_caller(%f: complex, %d: complex) -> (f32, f64) { @@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex, %d: complex) -> (f32, f64) { return %rf, %rd : f32, f64 } +//CHECK-LABEL: @angle_caller +func.func @angle_caller(%f: complex, %d: complex) -> (f32, f64) { + // CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}}) + %af = complex.angle %f : complex + // CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}}) + %ad = complex.angle %d : complex + // CHECK: return %[[AF]], %[[AD]] + return %af, %ad : f32, f64 +} + +//CHECK-LABEL: @cos_caller +func.func @cos_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}}) + %cf = complex.cos %f : complex + // CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}}) + %cd = complex.cos %d : complex + // CHECK: return %[[CF]], %[[CD]] + return %cf, %cd : complex, complex +} + //CHECK-LABEL: @exp_caller func.func @exp_caller(%f: complex, %d: complex) -> (complex, complex) { // CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}}) @@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex, %d: complex) -> (complex, comp // CHECK: return %[[EF]], %[[ED]] return %ef, %ed : complex, complex } + +//CHECK-LABEL: @log_caller +func.func @log_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}}) + %lf = complex.log %f : complex + // CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}}) + %ld = complex.log %d : complex + // CHECK: return %[[LF]], %[[LD]] + return %lf, %ld : complex, complex +} + +//CHECK-LABEL: @conj_caller +func.func @conj_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}}) + %cf2 = complex.conj %f : complex + // CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}}) + %cd2 = complex.conj %d : complex + // CHECK: return %[[CF]], %[[CD]] + return %cf2, %cd2 : complex, complex +} + +//CHECK-LABEL: @pow_caller +func.func @pow_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}}) + %pf = complex.pow %f, %f : complex + // CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}}) + %pd = complex.pow %d, %d : complex + // CHECK: return %[[PF]], %[[PD]] + return %pf, %pd : complex, complex +} + +//CHECK-LABEL: @sin_caller +func.func @sin_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) + %sf2 = complex.sin %f : complex + // CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}}) + %sd2 = complex.sin %d : complex + // CHECK: return %[[SF]], %[[SD]] + return %sf2, %sd2 : complex, complex +} + +//CHECK-LABEL: @sqrt_caller +func.func @sqrt_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}}) + %sf = complex.sqrt %f : complex + // CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}}) + %sd = complex.sqrt %d : complex + // CHECK: return %[[SF]], %[[SD]] + return %sf, %sd : complex, complex +} + +//CHECK-LABEL: @tan_caller +func.func @tan_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}}) + %tf2 = complex.tan %f : complex + // CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}}) + %td2 = complex.tan %d : complex + // CHECK: return %[[TF]], %[[TD]] + return %tf2, %td2 : complex, complex +} + +//CHECK-LABEL: @tanh_caller +func.func @tanh_caller(%f: complex, %d: complex) -> (complex, complex) { + // CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}}) + %tf = complex.tanh %f : complex + // CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}}) + %td = complex.tanh %d : complex + // CHECK: return %[[TF]], %[[TD]] + return %tf, %td : complex, complex +}