Skip to content

Commit c1f4107

Browse files
committed
Revert "[NFC] Remove invalid conversions in ComplexToROCDLLibraryCalls"
This reverts commit b8104fa.
1 parent ef2b880 commit c1f4107

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
6464
patterns.getContext(), "__ocml_cabs_f32");
6565
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
6666
patterns.getContext(), "__ocml_cabs_f64");
67+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
68+
patterns.getContext(), "__ocml_carg_f32");
69+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
70+
patterns.getContext(), "__ocml_carg_f64");
71+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
72+
patterns.getContext(), "__ocml_conj_f32");
73+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
74+
patterns.getContext(), "__ocml_conj_f64");
6775
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
6876
patterns.getContext(), "__ocml_ccos_f32");
6977
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
@@ -76,6 +84,10 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
7684
patterns.getContext(), "__ocml_clog_f32");
7785
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
7886
patterns.getContext(), "__ocml_clog_f64");
87+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
88+
patterns.getContext(), "__ocml_cpow_f32");
89+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
90+
patterns.getContext(), "__ocml_cpow_f64");
7991
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
8092
patterns.getContext(), "__ocml_csin_f32");
8193
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
@@ -110,8 +122,9 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
110122

111123
ConversionTarget target(getContext());
112124
target.addLegalDialect<func::FuncDialect>();
113-
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
114-
complex::LogOp, complex::SinOp, complex::SqrtOp,
125+
target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
126+
complex::CosOp, complex::ExpOp, complex::LogOp,
127+
complex::PowOp, complex::SinOp, complex::SqrtOp,
115128
complex::TanOp, complex::TanhOp>();
116129
if (failed(applyPartialConversion(op, target, std::move(patterns))))
117130
signalPassFailure();

mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
3333
return %rf, %rd : f32, f64
3434
}
3535

36+
//CHECK-LABEL: @angle_caller
37+
func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
38+
// CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
39+
%af = complex.angle %f : complex<f32>
40+
// CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
41+
%ad = complex.angle %d : complex<f64>
42+
// CHECK: return %[[AF]], %[[AD]]
43+
return %af, %ad : f32, f64
44+
}
45+
3646
//CHECK-LABEL: @cos_caller
3747
func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
3848
// CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
@@ -63,6 +73,26 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
6373
return %lf, %ld : complex<f32>, complex<f64>
6474
}
6575

76+
//CHECK-LABEL: @conj_caller
77+
func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
78+
// CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
79+
%cf2 = complex.conj %f : complex<f32>
80+
// CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
81+
%cd2 = complex.conj %d : complex<f64>
82+
// CHECK: return %[[CF]], %[[CD]]
83+
return %cf2, %cd2 : complex<f32>, complex<f64>
84+
}
85+
86+
//CHECK-LABEL: @pow_caller
87+
func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
88+
// CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
89+
%pf = complex.pow %f, %f : complex<f32>
90+
// CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
91+
%pd = complex.pow %d, %d : complex<f64>
92+
// CHECK: return %[[PF]], %[[PD]]
93+
return %pf, %pd : complex<f32>, complex<f64>
94+
}
95+
6696
//CHECK-LABEL: @sin_caller
6797
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
6898
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})

0 commit comments

Comments
 (0)