Skip to content

Commit 0a4c652

Browse files
authored
[MLIR] Add conversion support for more ops from ComplexToROCDLLibraryCalls (llvm#151166)
1 parent 2abd58c commit 0a4c652

File tree

2 files changed

+148
-1
lines changed

2 files changed

+148
-1
lines changed

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
6464
patterns.getContext(), "__ocml_cabs_f32");
6565
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
6666
patterns.getContext(), "__ocml_cabs_f64");
67+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
68+
patterns.getContext(), "__ocml_carg_f32");
69+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
70+
patterns.getContext(), "__ocml_carg_f64");
71+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
72+
patterns.getContext(), "__ocml_conj_f32");
73+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
74+
patterns.getContext(), "__ocml_conj_f64");
75+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
76+
patterns.getContext(), "__ocml_ccos_f32");
77+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
78+
patterns.getContext(), "__ocml_ccos_f64");
6779
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
6880
patterns.getContext(), "__ocml_cexp_f32");
6981
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
7082
patterns.getContext(), "__ocml_cexp_f64");
83+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
84+
patterns.getContext(), "__ocml_clog_f32");
85+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
86+
patterns.getContext(), "__ocml_clog_f64");
87+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
88+
patterns.getContext(), "__ocml_cpow_f32");
89+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
90+
patterns.getContext(), "__ocml_cpow_f64");
91+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
92+
patterns.getContext(), "__ocml_csin_f32");
93+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
94+
patterns.getContext(), "__ocml_csin_f64");
95+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
96+
patterns.getContext(), "__ocml_csqrt_f32");
97+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
98+
patterns.getContext(), "__ocml_csqrt_f64");
99+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
100+
patterns.getContext(), "__ocml_ctan_f32");
101+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
102+
patterns.getContext(), "__ocml_ctan_f64");
103+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
104+
patterns.getContext(), "__ocml_ctanh_f32");
105+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
106+
patterns.getContext(), "__ocml_ctanh_f64");
71107
}
72108

73109
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
86122

87123
ConversionTarget target(getContext());
88124
target.addLegalDialect<func::FuncDialect>();
89-
target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
125+
target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
126+
complex::CosOp, complex::ExpOp, complex::LogOp,
127+
complex::PowOp, complex::SinOp, complex::SqrtOp,
128+
complex::TanOp, complex::TanhOp>();
90129
if (failed(applyPartialConversion(op, target, std::move(patterns))))
91130
signalPassFailure();
92131
}

mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,26 @@
22

33
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
44
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
5+
// CHECK-DAG: @__ocml_carg_f32(complex<f32>) -> f32
6+
// CHECK-DAG: @__ocml_carg_f64(complex<f64>) -> f64
7+
// CHECK-DAG: @__ocml_ccos_f32(complex<f32>) -> complex<f32>
8+
// CHECK-DAG: @__ocml_ccos_f64(complex<f64>) -> complex<f64>
59
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
610
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
11+
// CHECK-DAG: @__ocml_clog_f32(complex<f32>) -> complex<f32>
12+
// CHECK-DAG: @__ocml_clog_f64(complex<f64>) -> complex<f64>
13+
// CHECK-DAG: @__ocml_conj_f32(complex<f32>) -> complex<f32>
14+
// CHECK-DAG: @__ocml_conj_f64(complex<f64>) -> complex<f64>
15+
// CHECK-DAG: @__ocml_cpow_f32(complex<f32>, complex<f32>) -> complex<f32>
16+
// CHECK-DAG: @__ocml_cpow_f64(complex<f64>, complex<f64>) -> complex<f64>
17+
// CHECK-DAG: @__ocml_csin_f32(complex<f32>) -> complex<f32>
18+
// CHECK-DAG: @__ocml_csin_f64(complex<f64>) -> complex<f64>
19+
// CHECK-DAG: @__ocml_csqrt_f32(complex<f32>) -> complex<f32>
20+
// CHECK-DAG: @__ocml_csqrt_f64(complex<f64>) -> complex<f64>
21+
// CHECK-DAG: @__ocml_ctan_f32(complex<f32>) -> complex<f32>
22+
// CHECK-DAG: @__ocml_ctan_f64(complex<f64>) -> complex<f64>
23+
// CHECK-DAG: @__ocml_ctanh_f32(complex<f32>) -> complex<f32>
24+
// CHECK-DAG: @__ocml_ctanh_f64(complex<f64>) -> complex<f64>
725

826
//CHECK-LABEL: @abs_caller
927
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
@@ -15,6 +33,26 @@ func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
1533
return %rf, %rd : f32, f64
1634
}
1735

36+
//CHECK-LABEL: @angle_caller
37+
func.func @angle_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
38+
// CHECK: %[[AF:.*]] = call @__ocml_carg_f32(%{{.*}})
39+
%af = complex.angle %f : complex<f32>
40+
// CHECK: %[[AD:.*]] = call @__ocml_carg_f64(%{{.*}})
41+
%ad = complex.angle %d : complex<f64>
42+
// CHECK: return %[[AF]], %[[AD]]
43+
return %af, %ad : f32, f64
44+
}
45+
46+
//CHECK-LABEL: @cos_caller
47+
func.func @cos_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
48+
// CHECK: %[[CF:.*]] = call @__ocml_ccos_f32(%{{.*}})
49+
%cf = complex.cos %f : complex<f32>
50+
// CHECK: %[[CD:.*]] = call @__ocml_ccos_f64(%{{.*}})
51+
%cd = complex.cos %d : complex<f64>
52+
// CHECK: return %[[CF]], %[[CD]]
53+
return %cf, %cd : complex<f32>, complex<f64>
54+
}
55+
1856
//CHECK-LABEL: @exp_caller
1957
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
2058
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
@@ -24,3 +62,73 @@ func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
2462
// CHECK: return %[[EF]], %[[ED]]
2563
return %ef, %ed : complex<f32>, complex<f64>
2664
}
65+
66+
//CHECK-LABEL: @log_caller
67+
func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
68+
// CHECK: %[[LF:.*]] = call @__ocml_clog_f32(%{{.*}})
69+
%lf = complex.log %f : complex<f32>
70+
// CHECK: %[[LD:.*]] = call @__ocml_clog_f64(%{{.*}})
71+
%ld = complex.log %d : complex<f64>
72+
// CHECK: return %[[LF]], %[[LD]]
73+
return %lf, %ld : complex<f32>, complex<f64>
74+
}
75+
76+
//CHECK-LABEL: @conj_caller
77+
func.func @conj_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
78+
// CHECK: %[[CF:.*]] = call @__ocml_conj_f32(%{{.*}})
79+
%cf2 = complex.conj %f : complex<f32>
80+
// CHECK: %[[CD:.*]] = call @__ocml_conj_f64(%{{.*}})
81+
%cd2 = complex.conj %d : complex<f64>
82+
// CHECK: return %[[CF]], %[[CD]]
83+
return %cf2, %cd2 : complex<f32>, complex<f64>
84+
}
85+
86+
//CHECK-LABEL: @pow_caller
87+
func.func @pow_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
88+
// CHECK: %[[PF:.*]] = call @__ocml_cpow_f32(%{{.*}}, %{{.*}})
89+
%pf = complex.pow %f, %f : complex<f32>
90+
// CHECK: %[[PD:.*]] = call @__ocml_cpow_f64(%{{.*}}, %{{.*}})
91+
%pd = complex.pow %d, %d : complex<f64>
92+
// CHECK: return %[[PF]], %[[PD]]
93+
return %pf, %pd : complex<f32>, complex<f64>
94+
}
95+
96+
//CHECK-LABEL: @sin_caller
97+
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
98+
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
99+
%sf2 = complex.sin %f : complex<f32>
100+
// CHECK: %[[SD:.*]] = call @__ocml_csin_f64(%{{.*}})
101+
%sd2 = complex.sin %d : complex<f64>
102+
// CHECK: return %[[SF]], %[[SD]]
103+
return %sf2, %sd2 : complex<f32>, complex<f64>
104+
}
105+
106+
//CHECK-LABEL: @sqrt_caller
107+
func.func @sqrt_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
108+
// CHECK: %[[SF:.*]] = call @__ocml_csqrt_f32(%{{.*}})
109+
%sf = complex.sqrt %f : complex<f32>
110+
// CHECK: %[[SD:.*]] = call @__ocml_csqrt_f64(%{{.*}})
111+
%sd = complex.sqrt %d : complex<f64>
112+
// CHECK: return %[[SF]], %[[SD]]
113+
return %sf, %sd : complex<f32>, complex<f64>
114+
}
115+
116+
//CHECK-LABEL: @tan_caller
117+
func.func @tan_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
118+
// CHECK: %[[TF:.*]] = call @__ocml_ctan_f32(%{{.*}})
119+
%tf2 = complex.tan %f : complex<f32>
120+
// CHECK: %[[TD:.*]] = call @__ocml_ctan_f64(%{{.*}})
121+
%td2 = complex.tan %d : complex<f64>
122+
// CHECK: return %[[TF]], %[[TD]]
123+
return %tf2, %td2 : complex<f32>, complex<f64>
124+
}
125+
126+
//CHECK-LABEL: @tanh_caller
127+
func.func @tanh_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
128+
// CHECK: %[[TF:.*]] = call @__ocml_ctanh_f32(%{{.*}})
129+
%tf = complex.tanh %f : complex<f32>
130+
// CHECK: %[[TD:.*]] = call @__ocml_ctanh_f64(%{{.*}})
131+
%td = complex.tanh %d : complex<f64>
132+
// CHECK: return %[[TF]], %[[TD]]
133+
return %tf, %td : complex<f32>, complex<f64>
134+
}

0 commit comments

Comments
 (0)