Skip to content

Commit bcf4b5a

Browse files
committed
Force lowering to complex.pow ops.
1 parent 53a4e4a commit bcf4b5a

File tree

15 files changed

+293
-38
lines changed

15 files changed

+293
-38
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
551551
"Prefer expanding without using Fortran runtime calls.">];
552552
}
553553

554+
def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
555+
let summary = "Convert complex.pow operations to library calls";
556+
let description = [{
557+
Replace `complex.pow` operations with calls to the appropriate
558+
Fortran runtime or libm functions.
559+
}];
560+
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
561+
"mlir::complex::ComplexDialect",
562+
"mlir::arith::ArithDialect"];
563+
}
564+
554565
def OptimizeArrayRepacking
555566
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
556567
let summary = "Optimizes redundant array repacking operations";

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13271327
const MathOperation &mathOp,
13281328
mlir::FunctionType mathLibFuncType,
13291329
llvm::ArrayRef<mlir::Value> args) {
1330-
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
1331-
if (!isAMDGPU)
1330+
if (mathRuntimeVersion == preciseVersion)
13321331
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
1333-
13341332
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
1335-
auto realTy = complexTy.getElementType();
1336-
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
1337-
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338-
mlir::Value complexExp =
1339-
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1340-
mlir::Value result =
1341-
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
1333+
mlir::Value exp = args[1];
1334+
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1335+
auto realTy = complexTy.getElementType();
1336+
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1337+
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338+
exp =
1339+
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1340+
}
1341+
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13421342
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
13431343
return result;
13441344
}
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
16681668
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
16691669
{"pow", "cpowf",
16701670
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
1671-
genComplexMathOp<mlir::complex::PowOp>},
1671+
genComplexPow},
16721672
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
1673-
genComplexMathOp<mlir::complex::PowOp>},
1673+
genComplexPow},
16741674
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
1675-
genLibF128Call},
1675+
genComplexPow},
16761676
{"pow", RTNAME_STRING(FPow4i),
16771677
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
16781678
genMathOp<mlir::math::FPowIOp>},
@@ -1698,15 +1698,15 @@ static constexpr MathOperation mathOperations[] = {
16981698
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
16991699
genComplexPow},
17001700
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
1701-
genLibF128Call},
1701+
genComplexPow},
17021702
{"pow", RTNAME_STRING(cpowk),
17031703
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
17041704
genComplexPow},
17051705
{"pow", RTNAME_STRING(zpowk),
17061706
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
17071707
genComplexPow},
17081708
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
1709-
genLibF128Call},
1709+
genComplexPow},
17101710
{"pow-unsigned", RTNAME_STRING(UPow1),
17111711
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
17121712
{"pow-unsigned", RTNAME_STRING(UPow2),

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
225225

226226
pm.addPass(mlir::createCanonicalizerPass(config));
227227
pm.addPass(fir::createSimplifyRegionLite());
228+
pm.addPass(fir::createConvertComplexPow());
228229
pm.addPass(mlir::createCSEPass());
229230

230231
if (pc.OptLevel.isOptimizingForSpeed())

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
3535
GenRuntimeCallsForTest.cpp
3636
SimplifyFIROperations.cpp
3737
OptimizeArrayRepacking.cpp
38+
ConvertComplexPow.cpp
3839

3940
DEPENDS
4041
CUFAttrs
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Common/static-multimap-view.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Dialect/FIRDialect.h"
12+
#include "flang/Optimizer/Transforms/Passes.h"
13+
#include "flang/Runtime/entry-names.h"
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Complex/IR/Complex.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
namespace fir {
20+
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
21+
#include "flang/Optimizer/Transforms/Passes.h.inc"
22+
} // namespace fir
23+
24+
using namespace mlir;
25+
26+
namespace {
27+
class ConvertComplexPowPass
28+
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
29+
public:
30+
void getDependentDialects(DialectRegistry &registry) const override {
31+
registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
32+
arith::ArithDialect, func::FuncDialect>();
33+
}
34+
void runOnOperation() override;
35+
};
36+
} // namespace
37+
38+
// Helper to declare or get a math library function.
39+
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
40+
StringRef name, FunctionType type) {
41+
if (auto func = builder.getNamedFunction(name))
42+
return func;
43+
auto func = builder.createFunction(loc, name, type);
44+
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
45+
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
46+
builder.getUnitAttr());
47+
return func;
48+
}
49+
50+
static bool isZero(Value v) {
51+
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
52+
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
53+
return attr.getValue().isZero();
54+
return false;
55+
}
56+
57+
void ConvertComplexPowPass::runOnOperation() {
58+
auto func = getOperation();
59+
auto mod = func->getParentOfType<ModuleOp>();
60+
if (fir::getTargetTriple(mod).isAMDGCN())
61+
return;
62+
63+
fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
64+
65+
func.walk([&](complex::PowOp op) {
66+
builder.setInsertionPoint(op);
67+
Location loc = op.getLoc();
68+
auto complexTy = cast<ComplexType>(op.getType());
69+
auto elemTy = complexTy.getElementType();
70+
71+
Value base = op.getLhs();
72+
Value rhs = op.getRhs();
73+
74+
Value intExp;
75+
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
76+
if (isZero(create.getImaginary())) {
77+
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
78+
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
79+
intExp = conv.getValue();
80+
}
81+
}
82+
}
83+
84+
func::FuncOp callee;
85+
SmallVector<Value> args;
86+
if (intExp) {
87+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
88+
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
89+
auto funcTy = builder.getFunctionType(
90+
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
91+
if (realBits == 32 && intBits == 32)
92+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
93+
else if (realBits == 32 && intBits == 64)
94+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
95+
else if (realBits == 64 && intBits == 32)
96+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
97+
else if (realBits == 64 && intBits == 64)
98+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
99+
else if (realBits == 128 && intBits == 32)
100+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
101+
else if (realBits == 128 && intBits == 64)
102+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
103+
else
104+
return;
105+
args = {base, intExp};
106+
} else {
107+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
108+
auto funcTy =
109+
builder.getFunctionType({complexTy, complexTy}, {complexTy});
110+
if (realBits == 32)
111+
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
112+
else if (realBits == 64)
113+
callee = getOrDeclare(builder, loc, "cpow", funcTy);
114+
else if (realBits == 128)
115+
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
116+
else
117+
return;
118+
args = {base, rhs};
119+
}
120+
121+
auto call = fir::CallOp::create(builder, loc, callee, args);
122+
op.replaceAllUsesWith(call.getResult(0));
123+
op.erase();
124+
});
125+
}

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
! CHECK-NEXT: SCFToControlFlow
7070
! CHECK-NEXT: Canonicalizer
7171
! CHECK-NEXT: SimplifyRegionLite
72+
! CHECK-NEXT: 'func.func' Pipeline
73+
! CHECK-NEXT: ConvertComplexPow
7274
! CHECK-NEXT: CSE
7375
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
7476
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@
9696
! ALL-NEXT: SCFToControlFlow
9797
! ALL-NEXT: Canonicalizer
9898
! ALL-NEXT: SimplifyRegionLite
99+
! ALL-NEXT: 'func.func' Pipeline
100+
! ALL-NEXT: ConvertComplexPow
99101
! ALL-NEXT: CSE
100102
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
101103
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@
127127
! ALL-NEXT: SCFToControlFlow
128128
! ALL-NEXT: Canonicalizer
129129
! ALL-NEXT: SimplifyRegionLite
130+
! ALL-NEXT: 'func.func' Pipeline
131+
! ALL-NEXT: ConvertComplexPow
130132
! ALL-NEXT: CSE
131133
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
132134
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Fir/basic-program.fir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ func.func @_QQmain() {
125125
// PASSES-NEXT: SCFToControlFlow
126126
// PASSES-NEXT: Canonicalizer
127127
// PASSES-NEXT: SimplifyRegionLite
128+
// PASSES-NEXT: 'func.func' Pipeline
129+
// PASSES-NEXT: ConvertComplexPow
128130
// PASSES-NEXT: CSE
129131
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
130132
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Lower/HLFIR/binary-ops.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
168168
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
169169
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
170170
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
171-
! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
171+
! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
172172

173173

174174
subroutine real_to_int_power(x, y, z)
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
193193
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
194194
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
195195
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
196-
! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
196+
! CHECK: %[[VAL_8:.*]] = complex.pow
197197

198198
subroutine extremum(c, n, l)
199199
integer(8), intent(in) :: l

0 commit comments

Comments
 (0)