From bcf4b5ada40dbb0d764eacb42047a31b16b6c89d Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Sat, 13 Sep 2025 01:39:27 +0100 Subject: [PATCH 1/4] Force lowering to complex.pow ops. --- .../flang/Optimizer/Transforms/Passes.td | 11 ++ flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 30 ++--- flang/lib/Optimizer/Passes/Pipelines.cpp | 1 + flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 + .../Transforms/ConvertComplexPow.cpp | 125 ++++++++++++++++++ flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 2 + .../test/Driver/mlir-debug-pass-pipeline.f90 | 2 + flang/test/Driver/mlir-pass-pipeline.f90 | 2 + flang/test/Fir/basic-program.fir | 2 + flang/test/Lower/HLFIR/binary-ops.f90 | 4 +- flang/test/Lower/Intrinsics/pow_complex16.f90 | 5 +- .../test/Lower/Intrinsics/pow_complex16i.f90 | 5 +- .../test/Lower/Intrinsics/pow_complex16k.f90 | 5 +- flang/test/Lower/power-operator.f90 | 34 ++--- flang/test/Transforms/convert-complex-pow.fir | 102 ++++++++++++++ 15 files changed, 293 insertions(+), 38 deletions(-) create mode 100644 flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp create mode 100644 flang/test/Transforms/convert-complex-pow.fir diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index e3001454cdf19..0ed4bb66aff0d 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> { "Prefer expanding without using Fortran runtime calls.">]; } +def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> { + let summary = "Convert complex.pow operations to library calls"; + let description = [{ + Replace `complex.pow` operations with calls to the appropriate + Fortran runtime or libm functions. + }]; + let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect", + "mlir::complex::ComplexDialect", + "mlir::arith::ArithDialect"]; +} + def OptimizeArrayRepacking : Pass<"optimize-array-repacking", "mlir::func::FuncOp"> { let summary = "Optimizes redundant array repacking operations"; diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index ce1376fd209cc..466458c05dba7 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, const MathOperation &mathOp, mlir::FunctionType mathLibFuncType, llvm::ArrayRef args) { - bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); - if (!isAMDGPU) + if (mathRuntimeVersion == preciseVersion) return genLibCall(builder, loc, mathOp, mathLibFuncType, args); - auto complexTy = mlir::cast(mathLibFuncType.getInput(0)); - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, args[1]); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - mlir::Value complexExp = - builder.create(loc, complexTy, realExp, zero); - mlir::Value result = - builder.create(loc, args[0], complexExp); + mlir::Value exp = args[1]; + if (!mlir::isa(exp.getType())) { + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, exp); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + exp = + builder.create(loc, complexTy, realExp, zero); + } + mlir::Value result = builder.create(loc, args[0], exp); result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); return result; } @@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = { {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call}, {"pow", "cpowf", genFuncType, Ty::Complex<4>, Ty::Complex<4>>, - genComplexMathOp}, + genComplexPow}, {"pow", "cpow", genFuncType, Ty::Complex<8>, Ty::Complex<8>>, - genComplexMathOp}, + genComplexPow}, {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16, - genLibF128Call}, + genComplexPow}, {"pow", RTNAME_STRING(FPow4i), genFuncType, Ty::Real<4>, Ty::Integer<4>>, genMathOp}, @@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = { genFuncType, Ty::Complex<8>, Ty::Integer<4>>, genComplexPow}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, - genLibF128Call}, + genComplexPow}, {"pow", RTNAME_STRING(cpowk), genFuncType, Ty::Complex<4>, Ty::Integer<8>>, genComplexPow}, @@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = { genFuncType, Ty::Complex<8>, Ty::Integer<8>>, genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, - genLibF128Call}, + genComplexPow}, {"pow-unsigned", RTNAME_STRING(UPow1), genFuncType, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, {"pow-unsigned", RTNAME_STRING(UPow2), diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 7c2777baebef1..ddcfffc9f158f 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -225,6 +225,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(mlir::createCanonicalizerPass(config)); pm.addPass(fir::createSimplifyRegionLite()); + pm.addPass(fir::createConvertComplexPow()); pm.addPass(mlir::createCSEPass()); if (pc.OptLevel.isOptimizingForSpeed()) diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index a8812e08c1ccd..4ec16274830fe 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -35,6 +35,7 @@ add_flang_library(FIRTransforms GenRuntimeCallsForTest.cpp SimplifyFIROperations.cpp OptimizeArrayRepacking.cpp + ConvertComplexPow.cpp DEPENDS CUFAttrs diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp new file mode 100644 index 0000000000000..8b62237cf539d --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -0,0 +1,125 @@ +//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Common/static-multimap-view.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Runtime/entry-names.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace fir { +#define GEN_PASS_DEF_CONVERTCOMPLEXPOW +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { +class ConvertComplexPowPass + : public fir::impl::ConvertComplexPowBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override; +}; +} // namespace + +// Helper to declare or get a math library function. +static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, + StringRef name, FunctionType type) { + if (auto func = builder.getNamedFunction(name)) + return func; + auto func = builder.createFunction(loc, name, type); + func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name)); + func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(), + builder.getUnitAttr()); + return func; +} + +static bool isZero(Value v) { + if (auto cst = v.getDefiningOp()) + if (auto attr = dyn_cast(cst.getValue())) + return attr.getValue().isZero(); + return false; +} + +void ConvertComplexPowPass::runOnOperation() { + auto func = getOperation(); + auto mod = func->getParentOfType(); + if (fir::getTargetTriple(mod).isAMDGCN()) + return; + + fir::FirOpBuilder builder(func, fir::getKindMapping(mod)); + + func.walk([&](complex::PowOp op) { + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + auto complexTy = cast(op.getType()); + auto elemTy = complexTy.getElementType(); + + Value base = op.getLhs(); + Value rhs = op.getRhs(); + + Value intExp; + if (auto create = rhs.getDefiningOp()) { + if (isZero(create.getImaginary())) { + if (auto conv = create.getReal().getDefiningOp()) { + if (auto intTy = dyn_cast(conv.getValue().getType())) + intExp = conv.getValue(); + } + } + } + + func::FuncOp callee; + SmallVector args; + if (intExp) { + unsigned realBits = cast(elemTy).getWidth(); + unsigned intBits = cast(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + args = {base, intExp}; + } else { + unsigned realBits = cast(elemTy).getWidth(); + auto funcTy = + builder.getFunctionType({complexTy, complexTy}, {complexTy}); + if (realBits == 32) + callee = getOrDeclare(builder, loc, "cpowf", funcTy); + else if (realBits == 64) + callee = getOrDeclare(builder, loc, "cpow", funcTy); + else if (realBits == 128) + callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); + else + return; + args = {base, rhs}; + } + + auto call = fir::CallOp::create(builder, loc, callee, args); + op.replaceAllUsesWith(call.getResult(0)); + op.erase(); + }); +} diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 index f3791fe9f8dc3..30cb97e4455ee 100644 --- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 +++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 @@ -69,6 +69,8 @@ ! CHECK-NEXT: SCFToControlFlow ! CHECK-NEXT: Canonicalizer ! CHECK-NEXT: SimplifyRegionLite +! CHECK-NEXT: 'func.func' Pipeline +! CHECK-NEXT: ConvertComplexPow ! CHECK-NEXT: CSE ! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90 index 42a71b2d6adc3..bb6d5509c3269 100644 --- a/flang/test/Driver/mlir-debug-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90 @@ -96,6 +96,8 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite +! ALL-NEXT: 'func.func' Pipeline +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 index e85a7728fc9af..6006f6672ee72 100644 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -127,6 +127,8 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite +! ALL-NEXT: 'func.func' Pipeline +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index 0a31397efb332..a2e3cda8f2325 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -125,6 +125,8 @@ func.func @_QQmain() { // PASSES-NEXT: SCFToControlFlow // PASSES-NEXT: Canonicalizer // PASSES-NEXT: SimplifyRegionLite +// PASSES-NEXT: 'func.func' Pipeline +// PASSES-NEXT: ConvertComplexPow // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 72cd048ea3615..1fbd333db37c3 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -168,7 +168,7 @@ subroutine complex_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref>, !fir.dscope) -> (!fir.ref>, !fir.ref>) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref> -! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath : (complex, complex) -> complex +! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath : complex subroutine real_to_int_power(x, y, z) @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref -! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath : (complex, i32) -> complex +! CHECK: %[[VAL_8:.*]] = complex.pow subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90 index 7467986832479..c026dd242e964 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex, complex) -> complex +! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex, complex) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a, b b = a ** b end diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 6f8684d9a663a..1827863a57f43 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex, i32) -> complex +! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex, i32) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index d3765050640ae..039dfd5152a06 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex, i64) -> complex +! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex, i64) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 7436e031d20cb..3058927144248 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -1,10 +1,10 @@ -! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE" -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" -! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST" -! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE" -! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST" -! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE" -! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST" +! RUN: bbc -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE +! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE +! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s ! Test power operation lowering @@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAcpowi + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAcpowi end subroutine ! CHECK-LABEL: pow_c4_i8 @@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAcpowk + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAcpowk end subroutine ! CHECK-LABEL: pow_c8_i4 @@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAzpowi + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAzpowi end subroutine ! CHECK-LABEL: pow_c8_i8 @@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAzpowk + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAzpowk end subroutine ! CHECK-LABEL: pow_c4_c4 subroutine pow_c4_c4(x, y, z) complex :: x, y, z z = x ** y - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex - ! PRECISE: call @cpowf + ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex + ! PRECISE: fir.call @cpowf end subroutine ! CHECK-LABEL: pow_c8_c8 subroutine pow_c8_c8(x, y, z) complex(8) :: x, y, z z = x ** y - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex - ! PRECISE: call @cpow + ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex + ! PRECISE: fir.call @cpow end subroutine diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir new file mode 100644 index 0000000000000..d980817aba9b9 --- /dev/null +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -0,0 +1,102 @@ +// RUN: fir-opt --convert-complex-pow %s | FileCheck %s + +module { + func.func @pow_c4_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.000000e+00 : f32 + %c1 = fir.convert %arg1 : (i32) -> f32 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c4_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.000000e+00 : f32 + %c1 = fir.convert %arg1 : (i64) -> f32 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c4_c4(%arg0: complex, %arg1: complex) -> complex { + %0 = complex.pow %arg0, %arg1 : complex + return %0 : complex + } + + func.func @pow_c8_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.000000e+00 : f64 + %c1 = fir.convert %arg1 : (i32) -> f64 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c8_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.000000e+00 : f64 + %c1 = fir.convert %arg1 : (i64) -> f64 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c8_c8(%arg0: complex, %arg1: complex) -> complex { + %0 = complex.pow %arg0, %arg1 : complex + return %0 : complex + } + + func.func @pow_c16_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.000000e+00 : f128 + %c1 = fir.convert %arg1 : (i32) -> f128 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c16_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.000000e+00 : f128 + %c1 = fir.convert %arg1 : (i64) -> f128 + %c2 = complex.create %c1, %c0 : complex + %0 = complex.pow %arg0, %c2 : complex + return %0 : complex + } + + func.func @pow_c16_c16(%arg0: complex, %arg1: complex) -> complex { + %0 = complex.pow %arg0, %arg1 : complex + return %0 : complex + } +} + +// CHECK-LABEL: func.func @pow_c4_i4( +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c4_i8( +// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c4_c4( +// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_i4( +// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_i8( +// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_c8( +// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_i4( +// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_i8( +// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_c16( +// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK-NOT: complex.pow From c8715a11d49e93240926a5b98e1f0b0e37b83f29 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Mon, 15 Sep 2025 18:19:50 +0100 Subject: [PATCH 2/4] Change ConverComplexPow from func to module pass. --- flang/include/flang/Optimizer/Transforms/Passes.td | 2 +- flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | 7 +++---- flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 3 +-- flang/test/Driver/mlir-debug-pass-pipeline.f90 | 3 +-- flang/test/Driver/mlir-pass-pipeline.f90 | 3 +-- flang/test/Fir/basic-program.fir | 3 +-- 6 files changed, 8 insertions(+), 13 deletions(-) diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index 0ed4bb66aff0d..093d5de028048 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -551,7 +551,7 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> { "Prefer expanding without using Fortran runtime calls.">]; } -def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> { +def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> { let summary = "Convert complex.pow operations to library calls"; let description = [{ Replace `complex.pow` operations with calls to the appropriate diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index 8b62237cf539d..dced5f90d6924 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -55,14 +55,13 @@ static bool isZero(Value v) { } void ConvertComplexPowPass::runOnOperation() { - auto func = getOperation(); - auto mod = func->getParentOfType(); + ModuleOp mod = getOperation(); if (fir::getTargetTriple(mod).isAMDGCN()) return; - fir::FirOpBuilder builder(func, fir::getKindMapping(mod)); + fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - func.walk([&](complex::PowOp op) { + mod.walk([&](complex::PowOp op) { builder.setInsertionPoint(op); Location loc = op.getLoc(); auto complexTy = cast(op.getType()); diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 index 30cb97e4455ee..bf2712d547a82 100644 --- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 +++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 @@ -69,8 +69,7 @@ ! CHECK-NEXT: SCFToControlFlow ! CHECK-NEXT: Canonicalizer ! CHECK-NEXT: SimplifyRegionLite -! CHECK-NEXT: 'func.func' Pipeline -! CHECK-NEXT: ConvertComplexPow +! CHECK-NEXT: ConvertComplexPow ! CHECK-NEXT: CSE ! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90 index bb6d5509c3269..5943a3c61c342 100644 --- a/flang/test/Driver/mlir-debug-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90 @@ -96,8 +96,7 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite -! ALL-NEXT: 'func.func' Pipeline -! ALL-NEXT: ConvertComplexPow +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 index 6006f6672ee72..4fd89d6f15d46 100644 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -127,8 +127,7 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite -! ALL-NEXT: 'func.func' Pipeline -! ALL-NEXT: ConvertComplexPow +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index a2e3cda8f2325..195e5ad7f9dc8 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -125,8 +125,7 @@ func.func @_QQmain() { // PASSES-NEXT: SCFToControlFlow // PASSES-NEXT: Canonicalizer // PASSES-NEXT: SimplifyRegionLite -// PASSES-NEXT: 'func.func' Pipeline -// PASSES-NEXT: ConvertComplexPow +// PASSES-NEXT: ConvertComplexPow // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd From 9528edd06fd17eae8a791c1f5536a3d545a42eb7 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Wed, 17 Sep 2025 18:50:34 +0100 Subject: [PATCH 3/4] Skip running on AMDGCN through pass pipelines. Propagate fast-math attributes, update tests. --- flang/include/flang/Tools/CrossToolHelpers.h | 1 + flang/lib/Frontend/FrontendActions.cpp | 2 ++ flang/lib/Optimizer/Passes/Pipelines.cpp | 3 +- .../Transforms/ConvertComplexPow.cpp | 5 ++- flang/test/Transforms/convert-complex-pow.fir | 36 +++++++++---------- flang/tools/bbc/bbc.cpp | 1 + 6 files changed, 26 insertions(+), 22 deletions(-) diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h index 335f0a45531c8..c2a4e082b129d 100644 --- a/flang/include/flang/Tools/CrossToolHelpers.h +++ b/flang/include/flang/Tools/CrossToolHelpers.h @@ -134,6 +134,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks { bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments. bool EnableOpenMP = false; ///< Enable OpenMP lowering. bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode. + bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion. std::string InstrumentFunctionEntry = ""; ///< Name of the instrument-function that is called on each ///< function-entry diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 23cc1e63e773d..33496bfe174a3 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -720,6 +720,7 @@ void CodeGenAction::generateLLVMIR() { const CodeGenOptions &opts = invoc.getCodeGenOpts(); const auto &mathOpts = invoc.getLoweringOpts().getMathOptions(); llvm::OptimizationLevel level = mapToLevel(opts); + const llvm::TargetMachine &targetMachine = ci.getTargetMachine(); mlir::DefaultTimingManager &timingMgr = ci.getTimingManager(); mlir::TimingScope &timingScopeRoot = ci.getTimingScopeRoot(); @@ -738,6 +739,7 @@ void CodeGenAction::generateLLVMIR() { pm.enableVerifier(/*verifyPasses=*/true); MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts); + config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN(); fir::registerDefaultInlinerPass(config); if (auto vsr = getVScaleRange(ci)) { diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index ddcfffc9f158f..805f84e888798 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -225,7 +225,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(mlir::createCanonicalizerPass(config)); pm.addPass(fir::createSimplifyRegionLite()); - pm.addPass(fir::createConvertComplexPow()); + if (!pc.SkipConvertComplexPow) + pm.addPass(fir::createConvertComplexPow()); pm.addPass(mlir::createCSEPass()); if (pc.OptLevel.isOptimizingForSpeed()) diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index dced5f90d6924..78f9d9e4f639a 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -56,9 +56,6 @@ static bool isZero(Value v) { void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); - if (fir::getTargetTriple(mod).isAMDGCN()) - return; - fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); mod.walk([&](complex::PowOp op) { @@ -118,6 +115,8 @@ void ConvertComplexPowPass::runOnOperation() { } auto call = fir::CallOp::create(builder, loc, callee, args); + if (auto fmf = op.getFastmathAttr()) + call.setFastmathAttr(fmf); op.replaceAllUsesWith(call.getResult(0)); op.erase(); }); diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir index d980817aba9b9..c63a44023a97c 100644 --- a/flang/test/Transforms/convert-complex-pow.fir +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -5,7 +5,7 @@ module { %c0 = arith.constant 0.000000e+00 : f32 %c1 = fir.convert %arg1 : (i32) -> f32 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } @@ -13,12 +13,12 @@ module { %c0 = arith.constant 0.000000e+00 : f32 %c1 = fir.convert %arg1 : (i64) -> f32 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } func.func @pow_c4_c4(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 : complex + %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex return %0 : complex } @@ -26,7 +26,7 @@ module { %c0 = arith.constant 0.000000e+00 : f64 %c1 = fir.convert %arg1 : (i32) -> f64 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } @@ -34,12 +34,12 @@ module { %c0 = arith.constant 0.000000e+00 : f64 %c1 = fir.convert %arg1 : (i64) -> f64 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } func.func @pow_c8_c8(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 : complex + %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex return %0 : complex } @@ -47,7 +47,7 @@ module { %c0 = arith.constant 0.000000e+00 : f128 %c1 = fir.convert %arg1 : (i32) -> f128 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } @@ -55,48 +55,48 @@ module { %c0 = arith.constant 0.000000e+00 : f128 %c1 = fir.convert %arg1 : (i64) -> f128 %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 : complex + %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex return %0 : complex } func.func @pow_c16_c16(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 : complex + %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex return %0 : complex } } // CHECK-LABEL: func.func @pow_c4_i4( -// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c4_i8( -// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c4_c4( -// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c8_i4( -// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c8_i8( -// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c8_c8( -// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c16_i4( -// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c16_i8( -// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c16_c16( -// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) : (complex, complex) -> complex +// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex // CHECK-NOT: complex.pow diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index 82dff2653ad09..69a45c66a079a 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -538,6 +538,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR( // Add O2 optimizer pass pipeline. MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2); + config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN(); if (enableOpenMP) config.EnableOpenMP = true; config.NSWOnLoopVarInc = !integerWrapAround; From 26329423b99dec16f6c9f93e6220527e00320092 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Wed, 17 Sep 2025 20:59:28 +0100 Subject: [PATCH 4/4] Propagate fastmath in ComplexToROCDL. Fix Targetmachine build error. --- flang/lib/Frontend/FrontendActions.cpp | 4 +- flang/test/Transforms/convert-complex-pow.fir | 127 ++++++++++-------- .../ComplexToROCDLLibraryCalls.cpp | 9 +- 3 files changed, 76 insertions(+), 64 deletions(-) diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 33496bfe174a3..6ebea5f8501b4 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -720,7 +720,6 @@ void CodeGenAction::generateLLVMIR() { const CodeGenOptions &opts = invoc.getCodeGenOpts(); const auto &mathOpts = invoc.getLoweringOpts().getMathOptions(); llvm::OptimizationLevel level = mapToLevel(opts); - const llvm::TargetMachine &targetMachine = ci.getTargetMachine(); mlir::DefaultTimingManager &timingMgr = ci.getTimingManager(); mlir::TimingScope &timingScopeRoot = ci.getTimingScopeRoot(); @@ -739,7 +738,8 @@ void CodeGenAction::generateLLVMIR() { pm.enableVerifier(/*verifyPasses=*/true); MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts); - config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN(); + llvm::Triple pipelineTriple(invoc.getTargetOpts().triple); + config.SkipConvertComplexPow = pipelineTriple.isAMDGCN(); fir::registerDefaultInlinerPass(config); if (auto vsr = getVScaleRange(ci)) { diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir index c63a44023a97c..e09fa7316c4b0 100644 --- a/flang/test/Transforms/convert-complex-pow.fir +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -2,101 +2,110 @@ module { func.func @pow_c4_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.000000e+00 : f32 - %c1 = fir.convert %arg1 : (i32) -> f32 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f32 + %0 = fir.convert %arg1 : (i32) -> f32 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex } func.func @pow_c4_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.000000e+00 : f32 - %c1 = fir.convert %arg1 : (i64) -> f32 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex - } - - func.func @pow_c4_c4(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f32 + %0 = fir.convert %arg1 : (i64) -> f32 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex } func.func @pow_c8_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.000000e+00 : f64 - %c1 = fir.convert %arg1 : (i32) -> f64 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f64 + %0 = fir.convert %arg1 : (i32) -> f64 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex } func.func @pow_c8_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.000000e+00 : f64 - %c1 = fir.convert %arg1 : (i64) -> f64 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex - } - - func.func @pow_c8_c8(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f64 + %0 = fir.convert %arg1 : (i64) -> f64 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex } func.func @pow_c16_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.000000e+00 : f128 - %c1 = fir.convert %arg1 : (i32) -> f128 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f128 + %0 = fir.convert %arg1 : (i32) -> f128 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex } func.func @pow_c16_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.000000e+00 : f128 - %c1 = fir.convert %arg1 : (i64) -> f128 - %c2 = complex.create %c1, %c0 : complex - %0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath} : complex - return %0 : complex + %c0 = arith.constant 0.0 : f128 + %0 = fir.convert %arg1 : (i64) -> f128 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c4_fast(%arg0: complex, %arg1: f32) -> complex { + %c1 = arith.constant 1.0 : f32 + %0 = complex.create %arg1, %c1 : complex + %1 = complex.pow %arg0, %0 fastmath : complex + return %1 : complex + } + + func.func @pow_c8_complex(%arg0: complex, %arg1: f64) -> complex { + %c2 = arith.constant 2.0 : f64 + %0 = complex.create %arg1, %c2 : complex + %1 = complex.pow %arg0, %0 : complex + return %1 : complex } - func.func @pow_c16_c16(%arg0: complex, %arg1: complex) -> complex { - %0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath} : complex - return %0 : complex + func.func @pow_c16_complex(%arg0: complex, %arg1: f128) -> complex { + %c3 = arith.constant 3.0 : f128 + %0 = complex.create %arg1, %c3 : complex + %1 = complex.pow %arg0, %0 : complex + return %1 : complex } } // CHECK-LABEL: func.func @pow_c4_i4( -// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c4_i8( -// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex -// CHECK-NOT: complex.pow - -// CHECK-LABEL: func.func @pow_c4_c4( -// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex +// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c8_i4( -// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex +// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c8_i8( -// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex -// CHECK-NOT: complex.pow - -// CHECK-LABEL: func.func @pow_c8_c8( -// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex +// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c16_i4( -// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex +// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow // CHECK-LABEL: func.func @pow_c16_i8( -// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) fastmath : (complex, i64) -> complex +// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow -// CHECK-LABEL: func.func @pow_c16_c16( -// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) fastmath : (complex, complex) -> complex +// CHECK-LABEL: func.func @pow_c4_fast( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath : (complex, complex) -> complex // CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_complex( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex, complex) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_complex( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex, complex) -> complex +// CHECK-NOT: complex.pow \ No newline at end of file diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 0372f32d6b6df..72b1fa6e833f9 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern { LogicalResult matchAndRewrite(complex::PowOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs()); - Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase); - Value exp = complex::ExpOp::create(rewriter, loc, mul); + auto fastmath = op.getFastmathAttr(); + Value logBase = + complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath); + Value mul = + complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath); + Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath); rewriter.replaceOp(op, exp); return success(); }