Skip to content

Commit dce85c3

Browse files
committed
Move cpow constant optimisation to Fortran lowering.
1 parent 9f01454 commit dce85c3

File tree

5 files changed

+29
-67
lines changed

5 files changed

+29
-67
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,14 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
12801280
const MathOperation &mathOp,
12811281
mlir::FunctionType mathLibFuncType,
12821282
llvm::ArrayRef<mlir::Value> args) {
1283+
if (auto expInt = fir::getIntIfConstant(args[1]))
1284+
if (*expInt >= 2 && *expInt <= 8) {
1285+
mlir::Value result = args[0];
1286+
for (int i = 1; i < *expInt; ++i)
1287+
result = builder.create<mlir::complex::MulOp>(loc, result, args[0]);
1288+
return builder.createConvert(loc, mathLibFuncType.getResult(0), result);
1289+
}
1290+
12831291
bool canUseApprox = mlir::arith::bitEnumContainsAny(
12841292
builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
12851293
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();

flang/test/Lower/amdgcn-complex.f90

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,19 @@ subroutine cexpf_test(a, b)
1818
b = exp(a)
1919
end subroutine
2020

21-
! CHECK-LABEL: func @_QPpow_test(
22-
! CHECK: complex.pow
21+
! CHECK-LABEL: func @_QPpow_test1(
22+
! CHECK: complex.mul
23+
! CHECK-NOT: complex.pow
2324
! CHECK-NOT: fir.call @_FortranAcpowi
24-
subroutine pow_test(a, b)
25+
subroutine pow_test1(a, b)
2526
complex :: a, b
2627
a = b**2
27-
end subroutine pow_test
28+
end subroutine pow_test1
29+
30+
! CHECK-LABEL: func @_QPpow_test2(
31+
! CHECK: complex.pow
32+
! CHECK-NOT: fir.call @_FortranAcpowi
33+
subroutine pow_test2(a, b, c)
34+
complex :: a, b, c
35+
a = b**c
36+
end subroutine pow_test2

flang/test/Lower/power-operator.f90

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,11 @@ subroutine pow_c8_c8(x, y, z)
143143
! PRECISE: call @cpow
144144
end subroutine
145145

146+
! CHECK-LABEL: pow_const
147+
subroutine pow_const(a, b)
148+
complex :: a, b
149+
! CHECK-NOT: complex.pow
150+
! CHECK-NOT: @_FortranAcpowi
151+
! CHECK-COUNT-3: complex.mul
152+
a = b**4
153+
end subroutine

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
10-
#include "mlir/Dialect/Arith/IR/Arith.h"
1110
#include "mlir/Dialect/Complex/IR/Complex.h"
1211
#include "mlir/Dialect/Func/IR/FuncOps.h"
1312
#include "mlir/IR/PatternMatch.h"
@@ -59,52 +58,12 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
5958
};
6059

6160
// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
62-
// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8
6361
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
6462
using OpRewritePattern<complex::PowOp>::OpRewritePattern;
6563

6664
LogicalResult matchAndRewrite(complex::PowOp op,
6765
PatternRewriter &rewriter) const final {
6866
auto loc = op.getLoc();
69-
70-
auto peelConst = [&](Value val) -> std::optional<TypedAttr> {
71-
while (val) {
72-
Operation *defOp = val.getDefiningOp();
73-
if (!defOp)
74-
return std::nullopt;
75-
76-
if (auto constVal = dyn_cast<arith::ConstantOp>(defOp))
77-
return dyn_cast<TypedAttr>(constVal.getValue());
78-
79-
if (defOp->getName().getStringRef() == "fir.convert" &&
80-
defOp->getNumOperands() == 1) {
81-
val = defOp->getOperand(0);
82-
continue;
83-
}
84-
return std::nullopt;
85-
}
86-
return std::nullopt;
87-
};
88-
89-
if (auto createOp = op.getRhs().getDefiningOp<complex::CreateOp>()) {
90-
auto image = peelConst(createOp.getImaginary());
91-
auto real = peelConst(createOp.getReal());
92-
if (image && real) {
93-
auto imagFloat = dyn_cast<FloatAttr>(*image);
94-
if (imagFloat && imagFloat.getValue().isZero()) {
95-
auto realInt = dyn_cast<IntegerAttr>(*real);
96-
if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) {
97-
Value base = op.getLhs();
98-
Value result = base;
99-
for (int i = 1; i < realInt.getInt(); ++i)
100-
result = rewriter.create<complex::MulOp>(loc, result, base);
101-
rewriter.replaceOp(op, result);
102-
return success();
103-
}
104-
}
105-
}
106-
}
107-
10867
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
10968
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
11069
Value exp = rewriter.create<complex::ExpOp>(loc, mul);

mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,28 +68,6 @@ func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
6868
return %r : complex<f32>
6969
}
7070

71-
// CHECK-LABEL: @pow_int_caller
72-
func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
73-
->(complex<f32>, complex<f64>) {
74-
// CHECK-NOT: call @__ocml
75-
// CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
76-
%c2_i32 = arith.constant 2 : i32
77-
%c2r = "fir.convert"(%c2_i32) : (i32) -> f32
78-
%c2i = arith.constant 0.0 : f32
79-
%c2 = complex.create %c2r, %c2i : complex<f32>
80-
%p2 = complex.pow %f, %c2 : complex<f32>
81-
// CHECK-NOT: call @__ocml
82-
// CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
83-
// CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
84-
%c3_i32 = arith.constant 3 : i32
85-
%c3r = "fir.convert"(%c3_i32) : (i32) -> f64
86-
%c3i = arith.constant 0.0 : f64
87-
%c3 = complex.create %c3r, %c3i : complex<f64>
88-
%p3 = complex.pow %d, %c3 : complex<f64>
89-
// CHECK: return %[[M2]], %[[M3B]]
90-
return %p2, %p3 : complex<f32>, complex<f64>
91-
}
92-
9371
//CHECK-LABEL: @sin_caller
9472
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
9573
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})

0 commit comments

Comments
 (0)