Skip to content

Commit 8c9a156

Browse files
committed
Add complex.powi op.
1 parent c8715a1 commit 8c9a156

File tree

14 files changed

+198
-106
lines changed

14 files changed

+198
-106
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13311331
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
13321332
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
13331333
mlir::Value exp = args[1];
1334-
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1335-
auto realTy = complexTy.getElementType();
1336-
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1337-
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338-
exp =
1339-
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1334+
mlir::Value result;
1335+
if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
1336+
mlir::isa<mlir::IndexType>(exp.getType())) {
1337+
result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
1338+
} else {
1339+
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1340+
auto realTy = complexTy.getElementType();
1341+
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1342+
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1343+
exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
1344+
zero);
1345+
}
1346+
result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13401347
}
1341-
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13421348
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
13431349
return result;
13441350
}

flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp

Lines changed: 42 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -61,63 +61,55 @@ void ConvertComplexPowPass::runOnOperation() {
6161

6262
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
6363

64-
mod.walk([&](complex::PowOp op) {
64+
mod.walk([&](complex::PowiOp op) {
6565
builder.setInsertionPoint(op);
6666
Location loc = op.getLoc();
6767
auto complexTy = cast<ComplexType>(op.getType());
6868
auto elemTy = complexTy.getElementType();
69-
7069
Value base = op.getLhs();
71-
Value rhs = op.getRhs();
72-
73-
Value intExp;
74-
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
75-
if (isZero(create.getImaginary())) {
76-
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
77-
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
78-
intExp = conv.getValue();
79-
}
80-
}
81-
}
82-
70+
Value intExp = op.getRhs();
8371
func::FuncOp callee;
84-
SmallVector<Value> args;
85-
if (intExp) {
86-
unsigned realBits = cast<FloatType>(elemTy).getWidth();
87-
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
88-
auto funcTy = builder.getFunctionType(
89-
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
90-
if (realBits == 32 && intBits == 32)
91-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
92-
else if (realBits == 32 && intBits == 64)
93-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
94-
else if (realBits == 64 && intBits == 32)
95-
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
96-
else if (realBits == 64 && intBits == 64)
97-
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
98-
else if (realBits == 128 && intBits == 32)
99-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
100-
else if (realBits == 128 && intBits == 64)
101-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
102-
else
103-
return;
104-
args = {base, intExp};
105-
} else {
106-
unsigned realBits = cast<FloatType>(elemTy).getWidth();
107-
auto funcTy =
108-
builder.getFunctionType({complexTy, complexTy}, {complexTy});
109-
if (realBits == 32)
110-
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
111-
else if (realBits == 64)
112-
callee = getOrDeclare(builder, loc, "cpow", funcTy);
113-
else if (realBits == 128)
114-
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
115-
else
116-
return;
117-
args = {base, rhs};
118-
}
72+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
73+
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
74+
auto funcTy = builder.getFunctionType(
75+
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
76+
if (realBits == 32 && intBits == 32)
77+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
78+
else if (realBits == 32 && intBits == 64)
79+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
80+
else if (realBits == 64 && intBits == 32)
81+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
82+
else if (realBits == 64 && intBits == 64)
83+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
84+
else if (realBits == 128 && intBits == 32)
85+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
86+
else if (realBits == 128 && intBits == 64)
87+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
88+
else
89+
return;
90+
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
91+
op.replaceAllUsesWith(call.getResult(0));
92+
op.erase();
93+
});
11994

120-
auto call = fir::CallOp::create(builder, loc, callee, args);
95+
mod.walk([&](complex::PowOp op) {
96+
builder.setInsertionPoint(op);
97+
Location loc = op.getLoc();
98+
auto complexTy = cast<ComplexType>(op.getType());
99+
auto elemTy = complexTy.getElementType();
100+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
101+
func::FuncOp callee;
102+
auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
103+
if (realBits == 32)
104+
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
105+
else if (realBits == 64)
106+
callee = getOrDeclare(builder, loc, "cpow", funcTy);
107+
else if (realBits == 128)
108+
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
109+
else
110+
return;
111+
auto call =
112+
fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
121113
op.replaceAllUsesWith(call.getResult(0));
122114
op.erase();
123115
});

flang/test/Lower/HLFIR/binary-ops.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
193193
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
194194
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
195195
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
196-
! CHECK: %[[VAL_8:.*]] = complex.pow
196+
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
197197

198198
subroutine extremum(c, n, l)
199199
integer(8), intent(in) :: l

flang/test/Lower/Intrinsics/pow_complex16i.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
88
complex(16) :: a
99
integer(4) :: b
1010
b = a ** b

flang/test/Lower/Intrinsics/pow_complex16k.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
88
complex(16) :: a
99
integer(8) :: b
1010
b = a ** b

flang/test/Lower/amdgcn-complex.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
2525
complex :: a, b, c
2626
a = b**c
2727
end subroutine pow_test
28+
29+
! CHECK-LABEL: func @_QPpowi_test(
30+
! CHECK: complex.powi
31+
! CHECK-NOT: fir.call @_FortranAcpowi
32+
subroutine powi_test(a, b, c)
33+
complex :: a, b
34+
integer :: i
35+
b = a ** i
36+
end subroutine powi_test

flang/test/Lower/power-operator.f90

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
9696
complex :: x, z
9797
integer :: y
9898
z = x ** y
99-
! CHECK: complex.pow
99+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
100100
! PRECISE: fir.call @_FortranAcpowi
101101
end subroutine
102102

@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
105105
complex :: x, z
106106
integer(8) :: y
107107
z = x ** y
108-
! CHECK: complex.pow
108+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
109109
! PRECISE: fir.call @_FortranAcpowk
110110
end subroutine
111111

@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
114114
complex(8) :: x, z
115115
integer :: y
116116
z = x ** y
117-
! CHECK: complex.pow
117+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
118118
! PRECISE: fir.call @_FortranAzpowi
119119
end subroutine
120120

@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
123123
complex(8) :: x, z
124124
integer(8) :: y
125125
z = x ** y
126-
! CHECK: complex.pow
126+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
127127
! PRECISE: fir.call @_FortranAzpowk
128128
end subroutine
129129

@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
142142
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
143143
! PRECISE: fir.call @cpow
144144
end subroutine
145-

flang/test/Transforms/convert-complex-pow.fir

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,12 @@
22

33
module {
44
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
5-
%c0 = arith.constant 0.000000e+00 : f32
6-
%c1 = fir.convert %arg1 : (i32) -> f32
7-
%c2 = complex.create %c1, %c0 : complex<f32>
8-
%0 = complex.pow %arg0, %c2 : complex<f32>
5+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
96
return %0 : complex<f32>
107
}
118

129
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
13-
%c0 = arith.constant 0.000000e+00 : f32
14-
%c1 = fir.convert %arg1 : (i64) -> f32
15-
%c2 = complex.create %c1, %c0 : complex<f32>
16-
%0 = complex.pow %arg0, %c2 : complex<f32>
10+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
1711
return %0 : complex<f32>
1812
}
1913

@@ -23,18 +17,12 @@ module {
2317
}
2418

2519
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
26-
%c0 = arith.constant 0.000000e+00 : f64
27-
%c1 = fir.convert %arg1 : (i32) -> f64
28-
%c2 = complex.create %c1, %c0 : complex<f64>
29-
%0 = complex.pow %arg0, %c2 : complex<f64>
20+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
3021
return %0 : complex<f64>
3122
}
3223

3324
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
34-
%c0 = arith.constant 0.000000e+00 : f64
35-
%c1 = fir.convert %arg1 : (i64) -> f64
36-
%c2 = complex.create %c1, %c0 : complex<f64>
37-
%0 = complex.pow %arg0, %c2 : complex<f64>
25+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
3826
return %0 : complex<f64>
3927
}
4028

@@ -44,18 +32,12 @@ module {
4432
}
4533

4634
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
47-
%c0 = arith.constant 0.000000e+00 : f128
48-
%c1 = fir.convert %arg1 : (i32) -> f128
49-
%c2 = complex.create %c1, %c0 : complex<f128>
50-
%0 = complex.pow %arg0, %c2 : complex<f128>
35+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
5136
return %0 : complex<f128>
5237
}
5338

5439
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
55-
%c0 = arith.constant 0.000000e+00 : f128
56-
%c1 = fir.convert %arg1 : (i64) -> f128
57-
%c2 = complex.create %c1, %c0 : complex<f128>
58-
%0 = complex.pow %arg0, %c2 : complex<f128>
40+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
5941
return %0 : complex<f128>
6042
}
6143

@@ -67,35 +49,35 @@ module {
6749

6850
// CHECK-LABEL: func.func @pow_c4_i4(
6951
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
70-
// CHECK-NOT: complex.pow
52+
// CHECK-NOT: complex.powi
7153

7254
// CHECK-LABEL: func.func @pow_c4_i8(
7355
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
74-
// CHECK-NOT: complex.pow
56+
// CHECK-NOT: complex.powi
7557

7658
// CHECK-LABEL: func.func @pow_c4_c4(
7759
// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32>
7860
// CHECK-NOT: complex.pow
7961

8062
// CHECK-LABEL: func.func @pow_c8_i4(
8163
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
82-
// CHECK-NOT: complex.pow
64+
// CHECK-NOT: complex.powi
8365

8466
// CHECK-LABEL: func.func @pow_c8_i8(
8567
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
86-
// CHECK-NOT: complex.pow
68+
// CHECK-NOT: complex.powi
8769

8870
// CHECK-LABEL: func.func @pow_c8_c8(
8971
// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex<f64>, complex<f64>) -> complex<f64>
9072
// CHECK-NOT: complex.pow
9173

9274
// CHECK-LABEL: func.func @pow_c16_i4(
9375
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
94-
// CHECK-NOT: complex.pow
76+
// CHECK-NOT: complex.powi
9577

9678
// CHECK-LABEL: func.func @pow_c16_i8(
9779
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
98-
// CHECK-NOT: complex.pow
80+
// CHECK-NOT: complex.powi
9981

10082
// CHECK-LABEL: func.func @pow_c16_c16(
10183
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex<f128>, complex<f128>) -> complex<f128>

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
443443
}];
444444
}
445445

446+
//===----------------------------------------------------------------------===//
447+
// PowiOp
448+
//===----------------------------------------------------------------------===//
449+
450+
def PowiOp : Complex_Op<"powi",
451+
[Pure, Elementwise, SameOperandsAndResultShape,
452+
AllTypesMatch<["lhs", "result"]>]> {
453+
let summary = "complex number raised to integer power";
454+
let description = [{
455+
The `powi` operation takes a complex number and an integer exponent.
456+
457+
Example:
458+
459+
```mlir
460+
%a = complex.powi %b, %c : complex<f32>, i32
461+
```
462+
}];
463+
464+
let arguments = (ins Complex<AnyFloat>:$lhs,
465+
AnySignlessInteger:$rhs);
466+
let results = (outs Complex<AnyFloat>:$result);
467+
468+
let assemblyFormat =
469+
"$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
470+
}
471+
446472
//===----------------------------------------------------------------------===//
447473
// ReOp
448474
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)