Skip to content

Commit 6976910

Browse files
committed
Add complex.powi op.
1 parent 2632942 commit 6976910

File tree

13 files changed

+188
-76
lines changed

13 files changed

+188
-76
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,14 +1331,20 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13311331
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
13321332
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
13331333
mlir::Value exp = args[1];
1334-
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1335-
auto realTy = complexTy.getElementType();
1336-
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1337-
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338-
exp =
1339-
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1334+
mlir::Value result;
1335+
if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
1336+
mlir::isa<mlir::IndexType>(exp.getType())) {
1337+
result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
1338+
} else {
1339+
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1340+
auto realTy = complexTy.getElementType();
1341+
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1342+
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1343+
exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
1344+
zero);
1345+
}
1346+
result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13401347
}
1341-
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13421348
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
13431349
return result;
13441350
}

flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp

Lines changed: 44 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -58,63 +58,57 @@ void ConvertComplexPowPass::runOnOperation() {
5858
ModuleOp mod = getOperation();
5959
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
6060

61-
mod.walk([&](complex::PowOp op) {
61+
mod.walk([&](complex::PowiOp op) {
6262
builder.setInsertionPoint(op);
6363
Location loc = op.getLoc();
6464
auto complexTy = cast<ComplexType>(op.getType());
6565
auto elemTy = complexTy.getElementType();
66-
6766
Value base = op.getLhs();
68-
Value rhs = op.getRhs();
69-
70-
Value intExp;
71-
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
72-
if (isZero(create.getImaginary())) {
73-
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
74-
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
75-
intExp = conv.getValue();
76-
}
77-
}
78-
}
79-
67+
Value intExp = op.getRhs();
8068
func::FuncOp callee;
81-
SmallVector<Value> args;
82-
if (intExp) {
83-
unsigned realBits = cast<FloatType>(elemTy).getWidth();
84-
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
85-
auto funcTy = builder.getFunctionType(
86-
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
87-
if (realBits == 32 && intBits == 32)
88-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
89-
else if (realBits == 32 && intBits == 64)
90-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
91-
else if (realBits == 64 && intBits == 32)
92-
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
93-
else if (realBits == 64 && intBits == 64)
94-
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
95-
else if (realBits == 128 && intBits == 32)
96-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
97-
else if (realBits == 128 && intBits == 64)
98-
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
99-
else
100-
return;
101-
args = {base, intExp};
102-
} else {
103-
unsigned realBits = cast<FloatType>(elemTy).getWidth();
104-
auto funcTy =
105-
builder.getFunctionType({complexTy, complexTy}, {complexTy});
106-
if (realBits == 32)
107-
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
108-
else if (realBits == 64)
109-
callee = getOrDeclare(builder, loc, "cpow", funcTy);
110-
else if (realBits == 128)
111-
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
112-
else
113-
return;
114-
args = {base, rhs};
115-
}
69+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
70+
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
71+
auto funcTy = builder.getFunctionType(
72+
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
73+
if (realBits == 32 && intBits == 32)
74+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
75+
else if (realBits == 32 && intBits == 64)
76+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
77+
else if (realBits == 64 && intBits == 32)
78+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
79+
else if (realBits == 64 && intBits == 64)
80+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
81+
else if (realBits == 128 && intBits == 32)
82+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
83+
else if (realBits == 128 && intBits == 64)
84+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
85+
else
86+
return;
87+
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
88+
if (auto fmf = op.getFastmathAttr())
89+
call.setFastmathAttr(fmf);
90+
op.replaceAllUsesWith(call.getResult(0));
91+
op.erase();
92+
});
11693

117-
auto call = fir::CallOp::create(builder, loc, callee, args);
94+
mod.walk([&](complex::PowOp op) {
95+
builder.setInsertionPoint(op);
96+
Location loc = op.getLoc();
97+
auto complexTy = cast<ComplexType>(op.getType());
98+
auto elemTy = complexTy.getElementType();
99+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
100+
func::FuncOp callee;
101+
auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy});
102+
if (realBits == 32)
103+
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
104+
else if (realBits == 64)
105+
callee = getOrDeclare(builder, loc, "cpow", funcTy);
106+
else if (realBits == 128)
107+
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
108+
else
109+
return;
110+
auto call =
111+
fir::CallOp::create(builder, loc, callee, {op.getLhs(), op.getRhs()});
118112
if (auto fmf = op.getFastmathAttr())
119113
call.setFastmathAttr(fmf);
120114
op.replaceAllUsesWith(call.getResult(0));

flang/test/Lower/HLFIR/binary-ops.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
193193
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
194194
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
195195
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
196-
! CHECK: %[[VAL_8:.*]] = complex.pow
196+
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
197197

198198
subroutine extremum(c, n, l)
199199
integer(8), intent(in) :: l

flang/test/Lower/Intrinsics/pow_complex16i.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
88
complex(16) :: a
99
integer(4) :: b
1010
b = a ** b

flang/test/Lower/Intrinsics/pow_complex16k.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
88
complex(16) :: a
99
integer(8) :: b
1010
b = a ** b

flang/test/Lower/amdgcn-complex.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
2525
complex :: a, b, c
2626
a = b**c
2727
end subroutine pow_test
28+
29+
! CHECK-LABEL: func @_QPpowi_test(
30+
! CHECK: complex.powi
31+
! CHECK-NOT: fir.call @_FortranAcpowi
32+
subroutine powi_test(a, b, c)
33+
complex :: a, b
34+
integer :: i
35+
b = a ** i
36+
end subroutine powi_test

flang/test/Lower/power-operator.f90

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
9696
complex :: x, z
9797
integer :: y
9898
z = x ** y
99-
! CHECK: complex.pow
99+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
100100
! PRECISE: fir.call @_FortranAcpowi
101101
end subroutine
102102

@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
105105
complex :: x, z
106106
integer(8) :: y
107107
z = x ** y
108-
! CHECK: complex.pow
108+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
109109
! PRECISE: fir.call @_FortranAcpowk
110110
end subroutine
111111

@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
114114
complex(8) :: x, z
115115
integer :: y
116116
z = x ** y
117-
! CHECK: complex.pow
117+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
118118
! PRECISE: fir.call @_FortranAzpowi
119119
end subroutine
120120

@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
123123
complex(8) :: x, z
124124
integer(8) :: y
125125
z = x ** y
126-
! CHECK: complex.pow
126+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
127127
! PRECISE: fir.call @_FortranAzpowk
128128
end subroutine
129129

@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
142142
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
143143
! PRECISE: fir.call @cpow
144144
end subroutine
145-

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,32 @@ def PowOp : ComplexArithmeticOp<"pow"> {
443443
}];
444444
}
445445

446+
//===----------------------------------------------------------------------===//
447+
// PowiOp
448+
//===----------------------------------------------------------------------===//
449+
450+
def PowiOp : Complex_Op<"powi",
451+
[Pure, Elementwise, SameOperandsAndResultShape,
452+
AllTypesMatch<["lhs", "result"]>]> {
453+
let summary = "complex number raised to integer power";
454+
let description = [{
455+
The `powi` operation takes a complex number and an integer exponent.
456+
457+
Example:
458+
459+
```mlir
460+
%a = complex.powi %b, %c : complex<f32>, i32
461+
```
462+
}];
463+
464+
let arguments = (ins Complex<AnyFloat>:$lhs,
465+
AnySignlessInteger:$rhs);
466+
let results = (outs Complex<AnyFloat>:$result);
467+
468+
let assemblyFormat =
469+
"$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
470+
}
471+
446472
//===----------------------------------------------------------------------===//
447473
// ReOp
448474
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
10+
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/Complex/IR/Complex.h"
1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/TypeUtilities.h"
1315
#include "mlir/Transforms/DialectConversion.h"
1416

1517
namespace mlir {
@@ -74,10 +76,40 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
7476
return success();
7577
}
7678
};
79+
80+
// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0))
81+
struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
82+
using OpRewritePattern<complex::PowiOp>::OpRewritePattern;
83+
84+
LogicalResult matchAndRewrite(complex::PowiOp op,
85+
PatternRewriter &rewriter) const final {
86+
auto complexType = cast<ComplexType>(getElementTypeOrSelf(op.getType()));
87+
Type elementType = complexType.getElementType();
88+
89+
Type exponentType = op.getRhs().getType();
90+
Type exponentFloatType = elementType;
91+
if (auto shapedType = dyn_cast<ShapedType>(exponentType))
92+
exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
93+
94+
Location loc = op.getLoc();
95+
Value exponentReal =
96+
rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
97+
Value zeroImag = rewriter.create<arith::ConstantOp>(
98+
loc, rewriter.getZeroAttr(exponentFloatType));
99+
Value exponent = rewriter.create<complex::CreateOp>(
100+
loc, op.getLhs().getType(), exponentReal, zeroImag);
101+
102+
rewriter
103+
.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
104+
exponent);
105+
return success();
106+
}
107+
};
77108
} // namespace
78109

79110
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
80111
RewritePatternSet &patterns) {
112+
patterns.add<PowiOpToROCDLLibraryCalls>(patterns.getContext());
81113
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
82114
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
83115
patterns.getContext(), "__ocml_cabs_f32");
@@ -128,11 +160,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
128160
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
129161

130162
ConversionTarget target(getContext());
131-
target.addLegalDialect<func::FuncDialect>();
132-
target.addLegalOp<complex::MulOp>();
163+
target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
164+
target.addLegalOp<complex::CreateOp, complex::MulOp>();
133165
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
134-
complex::LogOp, complex::PowOp, complex::SinOp,
135-
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
166+
complex::LogOp, complex::PowOp, complex::PowiOp,
167+
complex::SinOp, complex::SqrtOp, complex::TanOp,
168+
complex::TanhOp>();
136169
if (failed(applyPartialConversion(op, target, std::move(patterns))))
137170
signalPassFailure();
138171
}

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Complex/IR/Complex.h"
1617
#include "mlir/Dialect/Math/IR/Math.h"
1718
#include "mlir/Dialect/Math/Transforms/Passes.h"
1819
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -175,12 +176,20 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
175176

176177
Value one;
177178
Type opType = getElementTypeOrSelf(op.getType());
178-
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
179+
if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>) {
179180
one = arith::ConstantOp::create(rewriter, loc,
180181
rewriter.getFloatAttr(opType, 1.0));
181-
else
182+
} else if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>) {
183+
auto complexTy = cast<ComplexType>(opType);
184+
Type elementType = complexTy.getElementType();
185+
auto realPart = rewriter.getFloatAttr(elementType, 1.0);
186+
auto imagPart = rewriter.getFloatAttr(elementType, 0.0);
187+
one = rewriter.create<complex::ConstantOp>(
188+
loc, complexTy, rewriter.getArrayAttr({realPart, imagPart}));
189+
} else {
182190
one = arith::ConstantOp::create(rewriter, loc,
183191
rewriter.getIntegerAttr(opType, 1));
192+
}
184193

185194
// Replace `[fi]powi(x, 0)` with `1`.
186195
if (exponentValue == 0) {
@@ -224,9 +233,10 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
224233

225234
void mlir::populateMathAlgebraicSimplificationPatterns(
226235
RewritePatternSet &patterns) {
227-
patterns
228-
.add<PowFStrengthReduction,
229-
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
230-
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
231-
patterns.getContext());
236+
patterns.add<
237+
PowFStrengthReduction,
238+
PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
239+
PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>,
240+
PowIStrengthReduction<complex::PowiOp, complex::DivOp, complex::MulOp>>(
241+
patterns.getContext(), /*exponentThreshold=*/8);
232242
}

0 commit comments

Comments
 (0)