Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}

def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> {
let summary = "Convert complex.pow operations to library calls";
let description = [{
Replace `complex.pow` operations with calls to the appropriate
Fortran runtime or libm functions.
}];
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
"mlir::complex::ComplexDialect",
"mlir::arith::ArithDialect"];
}

def OptimizeArrayRepacking
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
let summary = "Optimizes redundant array repacking operations";
Expand Down
1 change: 1 addition & 0 deletions flang/include/flang/Tools/CrossToolHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments.
bool EnableOpenMP = false; ///< Enable OpenMP lowering.
bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode.
bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion.
std::string InstrumentFunctionEntry =
""; ///< Name of the instrument-function that is called on each
///< function-entry
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Frontend/FrontendActions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() {
pm.enableVerifier(/*verifyPasses=*/true);

MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
fir::registerDefaultInlinerPass(config);

if (auto vsr = getVScaleRange(ci)) {
Expand Down
30 changes: 15 additions & 15 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
if (!isAMDGPU)
if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);

auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
mlir::Value complexExp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
mlir::Value result =
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
mlir::Value exp = args[1];
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
exp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
}
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
Expand Down Expand Up @@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
genComplexMathOp<mlir::complex::PowOp>},
genComplexPow},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
genComplexMathOp<mlir::complex::PowOp>},
genComplexPow},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
genLibF128Call},
genComplexPow},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
Expand All @@ -1698,15 +1698,15 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
genComplexPow},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you have complex.powi, I think we can just use genComplexMathOp<mlir::complex::powi> or genMathOp<mlir::complex::powi> here.

We can probably get rid of genComplexPow and use genMathOp instead.

Copy link
Member Author

@TIFitis TIFitis Sep 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

genComplexMathOp would lower to libCall if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
which means we would restrict lowering to complex.pow for some cases where we are currently forcing it. Is that okay?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if (!forceMlirComplex && !canUseApprox && !isAMDGPU) check is yet another workaround that has to be removed eventually (not in this PR).

I think genMathOp should work here just fine or am I missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I glossed over genMathOp, I've added this change in #158722 to removed genComplexPow.

{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
genComplexPow},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,

pm.addPass(mlir::createCanonicalizerPass(config));
pm.addPass(fir::createSimplifyRegionLite());
if (!pc.SkipConvertComplexPow)
pm.addPass(fir::createConvertComplexPow());
pm.addPass(mlir::createCSEPass());

if (pc.OptLevel.isOptimizingForSpeed())
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp
OptimizeArrayRepacking.cpp
ConvertComplexPow.cpp

DEPENDS
CUFAttrs
Expand Down
123 changes: 123 additions & 0 deletions flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Common/static-multimap-view.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

namespace fir {
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace mlir;

namespace {
class ConvertComplexPowPass
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
arith::ArithDialect, func::FuncDialect>();
}
void runOnOperation() override;
};
} // namespace

// Helper to declare or get a math library function.
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
StringRef name, FunctionType type) {
if (auto func = builder.getNamedFunction(name))
return func;
auto func = builder.createFunction(loc, name, type);
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
builder.getUnitAttr());
return func;
}

static bool isZero(Value v) {
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
return attr.getValue().isZero();
return false;
}

void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));

mod.walk([&](complex::PowOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();

Value base = op.getLhs();
Value rhs = op.getRhs();

Value intExp;
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
if (isZero(create.getImaginary())) {
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
intExp = conv.getValue();
}
}
}

func::FuncOp callee;
SmallVector<Value> args;
if (intExp) {
unsigned realBits = cast<FloatType>(elemTy).getWidth();
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
auto funcTy = builder.getFunctionType(
{complexTy, builder.getIntegerType(intBits)}, {complexTy});
if (realBits == 32 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
else if (realBits == 32 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
else if (realBits == 64 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
else if (realBits == 64 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
else if (realBits == 128 && intBits == 32)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
else if (realBits == 128 && intBits == 64)
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
else
return;
Comment on lines +99 to +100
Copy link
Preview

Copilot AI Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Early returns without error handling or logging make debugging difficult. Consider adding a diagnostic message or comment explaining why these combinations are unsupported.

Copilot uses AI. Check for mistakes.

args = {base, intExp};
} else {
unsigned realBits = cast<FloatType>(elemTy).getWidth();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am really worried about dropping the imaginary part for these cases. Imagine, somewhere in Flang we start generating complex.pow with a true complex exponent. This pass will just silently drop it and produce incorrect code. Please add a LIT test for this case.

I think we need to keep complex.pow if we cannot prove that the imaginary part is zero.

Ideally, we should have powi and powf operations in the complex dialect, so that we do not have to rely on the particular fir.convert/complex.create pattern generated by the lowering. Moreover, SSA values may become block arguments making it harder to recognize the specific pattern even more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, we should have powi and powf operations in the complex dialect, so that we do not have to rely on the particular fir.convert/complex.create pattern generated by the lowering. Moreover, SSA values may become block arguments making it harder to recognize the specific pattern even more.

I have added powi in #158722. I'll address the rest of this comment along with other comments tomorrow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for adding powi! I think adding powf is not required. I just misread the code.

auto funcTy =
builder.getFunctionType({complexTy, complexTy}, {complexTy});
if (realBits == 32)
callee = getOrDeclare(builder, loc, "cpowf", funcTy);
else if (realBits == 64)
callee = getOrDeclare(builder, loc, "cpow", funcTy);
else if (realBits == 128)
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
else
return;
Comment on lines +112 to +113
Copy link
Preview

Copilot AI Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Early return without error handling or logging makes debugging difficult. Consider adding a diagnostic message or comment explaining why this bit width is unsupported.

Suggested change
else
return;
else {
emitWarning(loc, "Unsupported complex.pow bit width: realBits=" +
std::to_string(realBits));
return;
}

Copilot uses AI. Check for mistakes.

args = {base, rhs};
}

auto call = fir::CallOp::create(builder, loc, callee, args);
if (auto fmf = op.getFastmathAttr())
call.setFastmathAttr(fmf);
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
}
1 change: 1 addition & 0 deletions flang/test/Driver/bbc-mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
Expand Down
1 change: 1 addition & 0 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/binary-ops.f90
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>


subroutine real_to_int_power(x, y, z)
Expand All @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
! CHECK: %[[VAL_8:.*]] = complex.pow

subroutine extremum(c, n, l)
integer(8), intent(in) :: l
Expand Down
5 changes: 3 additions & 2 deletions flang/test/Lower/Intrinsics/pow_complex16.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s

! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a, b
b = a ** b
end
5 changes: 3 additions & 2 deletions flang/test/Lower/Intrinsics/pow_complex16i.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s

! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
Expand Down
5 changes: 3 additions & 2 deletions flang/test/Lower/Intrinsics/pow_complex16k.f90
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s

! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
Expand Down
Loading