Skip to content

Commit 54677d6

Browse files
authored
[Flang] Add new ConvertComplexPow pass for Flang (#158642)
This PR introduces a new `ConvertComplexPow` pass for Flang that handles complex power operations. The change forces lowering to complex.pow operations when `--math-runtime=precise` is not used, then uses the `ConvertComplexPow` pass to convert these operations back to library calls. - Adds a new `ConvertComplexPow` pass that converts complex.pow ops to appropriate runtime library calls - Updates complex power lowering to use `complex.pow` operations by default instead of direct library calls #158722 Adds a new `complex.powi` op enabling algebraic optimisations.
1 parent 01fca01 commit 54677d6

File tree

19 files changed

+307
-41
lines changed

19 files changed

+307
-41
lines changed

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
555555
"Prefer expanding without using Fortran runtime calls.">];
556556
}
557557

558+
def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
559+
let summary = "Convert complex.pow operations to library calls";
560+
let description = [{
561+
Replace `complex.pow` operations with calls to the appropriate
562+
Fortran runtime or libm functions.
563+
}];
564+
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
565+
"mlir::complex::ComplexDialect",
566+
"mlir::arith::ArithDialect"];
567+
}
568+
558569
def OptimizeArrayRepacking
559570
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
560571
let summary = "Optimizes redundant array repacking operations";

flang/include/flang/Tools/CrossToolHelpers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
135135
bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments.
136136
bool EnableOpenMP = false; ///< Enable OpenMP lowering.
137137
bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode.
138+
bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion.
138139
std::string InstrumentFunctionEntry =
139140
""; ///< Name of the instrument-function that is called on each
140141
///< function-entry

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() {
738738
pm.enableVerifier(/*verifyPasses=*/true);
739739

740740
MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
741+
llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
742+
config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
741743
fir::registerDefaultInlinerPass(config);
742744

743745
if (auto vsr = getVScaleRange(ci)) {

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13271327
const MathOperation &mathOp,
13281328
mlir::FunctionType mathLibFuncType,
13291329
llvm::ArrayRef<mlir::Value> args) {
1330-
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
1331-
if (!isAMDGPU)
1330+
if (mathRuntimeVersion == preciseVersion)
13321331
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
1333-
13341332
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
1335-
auto realTy = complexTy.getElementType();
1336-
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
1337-
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338-
mlir::Value complexExp =
1339-
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1340-
mlir::Value result =
1341-
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
1333+
mlir::Value exp = args[1];
1334+
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1335+
auto realTy = complexTy.getElementType();
1336+
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1337+
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338+
exp =
1339+
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1340+
}
1341+
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
13421342
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
13431343
return result;
13441344
}
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
16681668
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
16691669
{"pow", "cpowf",
16701670
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
1671-
genComplexMathOp<mlir::complex::PowOp>},
1671+
genComplexPow},
16721672
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
1673-
genComplexMathOp<mlir::complex::PowOp>},
1673+
genComplexPow},
16741674
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
1675-
genLibF128Call},
1675+
genComplexPow},
16761676
{"pow", RTNAME_STRING(FPow4i),
16771677
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
16781678
genMathOp<mlir::math::FPowIOp>},
@@ -1698,15 +1698,15 @@ static constexpr MathOperation mathOperations[] = {
16981698
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
16991699
genComplexPow},
17001700
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
1701-
genLibF128Call},
1701+
genComplexPow},
17021702
{"pow", RTNAME_STRING(cpowk),
17031703
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
17041704
genComplexPow},
17051705
{"pow", RTNAME_STRING(zpowk),
17061706
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
17071707
genComplexPow},
17081708
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
1709-
genLibF128Call},
1709+
genComplexPow},
17101710
{"pow-unsigned", RTNAME_STRING(UPow1),
17111711
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
17121712
{"pow-unsigned", RTNAME_STRING(UPow2),

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
226226

227227
pm.addPass(mlir::createCanonicalizerPass(config));
228228
pm.addPass(fir::createSimplifyRegionLite());
229+
if (!pc.SkipConvertComplexPow)
230+
pm.addPass(fir::createConvertComplexPow());
229231
pm.addPass(mlir::createCSEPass());
230232

231233
if (pc.OptLevel.isOptimizingForSpeed())

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
3535
GenRuntimeCallsForTest.cpp
3636
SimplifyFIROperations.cpp
3737
OptimizeArrayRepacking.cpp
38+
ConvertComplexPow.cpp
3839

3940
DEPENDS
4041
CUFAttrs
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Common/static-multimap-view.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Dialect/FIRDialect.h"
12+
#include "flang/Optimizer/Transforms/Passes.h"
13+
#include "flang/Runtime/entry-names.h"
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Complex/IR/Complex.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Pass/Pass.h"
18+
19+
namespace fir {
20+
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
21+
#include "flang/Optimizer/Transforms/Passes.h.inc"
22+
} // namespace fir
23+
24+
using namespace mlir;
25+
26+
namespace {
27+
class ConvertComplexPowPass
28+
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
29+
public:
30+
void getDependentDialects(DialectRegistry &registry) const override {
31+
registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
32+
arith::ArithDialect, func::FuncDialect>();
33+
}
34+
void runOnOperation() override;
35+
};
36+
} // namespace
37+
38+
// Helper to declare or get a math library function.
39+
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
40+
StringRef name, FunctionType type) {
41+
if (auto func = builder.getNamedFunction(name))
42+
return func;
43+
auto func = builder.createFunction(loc, name, type);
44+
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
45+
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
46+
builder.getUnitAttr());
47+
return func;
48+
}
49+
50+
static bool isZero(Value v) {
51+
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
52+
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
53+
return attr.getValue().isZero();
54+
return false;
55+
}
56+
57+
void ConvertComplexPowPass::runOnOperation() {
58+
ModuleOp mod = getOperation();
59+
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
60+
61+
mod.walk([&](complex::PowOp op) {
62+
builder.setInsertionPoint(op);
63+
Location loc = op.getLoc();
64+
auto complexTy = cast<ComplexType>(op.getType());
65+
auto elemTy = complexTy.getElementType();
66+
67+
Value base = op.getLhs();
68+
Value rhs = op.getRhs();
69+
70+
Value intExp;
71+
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
72+
if (isZero(create.getImaginary())) {
73+
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
74+
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
75+
intExp = conv.getValue();
76+
}
77+
}
78+
}
79+
80+
func::FuncOp callee;
81+
SmallVector<Value> args;
82+
if (intExp) {
83+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
84+
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
85+
auto funcTy = builder.getFunctionType(
86+
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
87+
if (realBits == 32 && intBits == 32)
88+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
89+
else if (realBits == 32 && intBits == 64)
90+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
91+
else if (realBits == 64 && intBits == 32)
92+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
93+
else if (realBits == 64 && intBits == 64)
94+
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
95+
else if (realBits == 128 && intBits == 32)
96+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
97+
else if (realBits == 128 && intBits == 64)
98+
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
99+
else
100+
return;
101+
args = {base, intExp};
102+
} else {
103+
unsigned realBits = cast<FloatType>(elemTy).getWidth();
104+
auto funcTy =
105+
builder.getFunctionType({complexTy, complexTy}, {complexTy});
106+
if (realBits == 32)
107+
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
108+
else if (realBits == 64)
109+
callee = getOrDeclare(builder, loc, "cpow", funcTy);
110+
else if (realBits == 128)
111+
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
112+
else
113+
return;
114+
args = {base, rhs};
115+
}
116+
117+
auto call = fir::CallOp::create(builder, loc, callee, args);
118+
if (auto fmf = op.getFastmathAttr())
119+
call.setFastmathAttr(fmf);
120+
op.replaceAllUsesWith(call.getResult(0));
121+
op.erase();
122+
});
123+
}

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
! CHECK-NEXT: SCFToControlFlow
7070
! CHECK-NEXT: Canonicalizer
7171
! CHECK-NEXT: SimplifyRegionLite
72+
! CHECK-NEXT: ConvertComplexPow
7273
! CHECK-NEXT: CSE
7374
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
7475
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
! ALL-NEXT: SCFToControlFlow
9797
! ALL-NEXT: Canonicalizer
9898
! ALL-NEXT: SimplifyRegionLite
99+
! ALL-NEXT: ConvertComplexPow
99100
! ALL-NEXT: CSE
100101
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
101102
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
! ALL-NEXT: SCFToControlFlow
128128
! ALL-NEXT: Canonicalizer
129129
! ALL-NEXT: SimplifyRegionLite
130+
! ALL-NEXT: ConvertComplexPow
130131
! ALL-NEXT: CSE
131132
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
132133
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd

0 commit comments

Comments
 (0)