From 57181a27698fba69200ec20af6c5743acdc57f3a Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 12 Aug 2025 14:06:38 +0100 Subject: [PATCH 1/6] [MLIR] Add cpow support in ComplexToROCDLLibraryCalls This PR contributes the following changes: 1. Force lowering to complex.pow ops for the amdgcn-amd-amdhsa target. 2. Convert complex.pow(z, w) -> complex.exp(w * complex.log(z)). 3. Convert x ** 2 -> x * x, x ** 3 -> x * x * x, ... x ** 8 -> x * x... . --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 34 +++++++++++++++-- flang/test/Lower/amdgcn-complex.f90 | 22 +++++++---- flang/test/Lower/power-operator.f90 | 12 ++++-- .../ComplexToROCDLLibraryCalls.cpp | 38 ++++++++++++++++++- .../complex-to-rocdl-library-calls.mlir | 27 +++++++++++++ 5 files changed, 115 insertions(+), 18 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 22193f0de88a1..74279a7d72078 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } +mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc, + const MathOperation &mathOp, + mlir::FunctionType mathLibFuncType, + llvm::ArrayRef args) { + bool canUseApprox = mlir::arith::bitEnumContainsAny( + builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); + bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); + if (!forceMlirComplex && !canUseApprox && !isAMDGPU) + return genLibCall(builder, loc, mathOp, mathLibFuncType, args); + + auto complexTy = mlir::cast(mathLibFuncType.getInput(0)); + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, args[1]); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + mlir::Value complexExp = + builder.create(loc, complexTy, realExp, zero); + mlir::Value result = + builder.create(loc, args[0], complexExp); + result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); + return result; +} + /// Mapping between mathematical intrinsic operations and MLIR operations /// of some appropriate dialect (math, complex, etc.) or libm calls. /// TODO: support remaining Fortran math intrinsics. @@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = { genFuncType, Ty::Real<16>, Ty::Integer<8>>, genMathOp}, {"pow", RTNAME_STRING(cpowi), - genFuncType, Ty::Complex<4>, Ty::Integer<4>>, genLibCall}, + genFuncType, Ty::Complex<4>, Ty::Integer<4>>, + genComplexPowI}, {"pow", RTNAME_STRING(zpowi), - genFuncType, Ty::Complex<8>, Ty::Integer<4>>, genLibCall}, + genFuncType, Ty::Complex<8>, Ty::Integer<4>>, + genComplexPowI}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, genLibF128Call}, {"pow", RTNAME_STRING(cpowk), - genFuncType, Ty::Complex<4>, Ty::Integer<8>>, genLibCall}, + genFuncType, Ty::Complex<4>, Ty::Integer<8>>, + genComplexPowI}, {"pow", RTNAME_STRING(zpowk), - genFuncType, Ty::Complex<8>, Ty::Integer<8>>, genLibCall}, + genFuncType, Ty::Complex<8>, Ty::Integer<8>>, + genComplexPowI}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, genLibF128Call}, {"remainder", "remainderf", diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index f15c7db2b7316..3d52355d3d50a 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -1,21 +1,27 @@ ! REQUIRES: amdgpu-registered-target -! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s +! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s +! CHECK-LABEL: func @_QPcabsf_test( +! CHECK: complex.abs +! CHECK-NOT: fir.call @cabsf subroutine cabsf_test(a, b) complex :: a real :: b b = abs(a) end subroutine -! CHECK-LABEL: func @_QPcabsf_test( -! CHECK: complex.abs -! CHECK-NOT: fir.call @cabsf - +! CHECK-LABEL: func @_QPcexpf_test( +! CHECK: complex.exp +! CHECK-NOT: fir.call @cexpf subroutine cexpf_test(a, b) complex :: a, b b = exp(a) end subroutine -! CHECK-LABEL: func @_QPcexpf_test( -! CHECK: complex.exp -! CHECK-NOT: fir.call @cexpf +! CHECK-LABEL: func @_QPpow_test( +! CHECK: complex.pow +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine pow_test(a, b) + complex :: a, b + a = b**2 +end subroutine pow_test diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 7436e031d20cb..2a0a09e090dde 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAcpowi + ! PRECISE: call @_FortranAcpowi + ! FAST: complex.pow %{{.*}}, %{{.*}} : complex end subroutine ! CHECK-LABEL: pow_c4_i8 @@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAcpowk + ! PRECISE: call @_FortranAcpowk + ! FAST: complex.pow %{{.*}}, %{{.*}} : complex end subroutine ! CHECK-LABEL: pow_c8_i4 @@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAzpowi + ! PRECISE: call @_FortranAzpowi + ! FAST: complex.pow %{{.*}}, %{{.*}} : complex end subroutine ! CHECK-LABEL: pow_c8_i8 @@ -120,7 +123,8 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAzpowk + ! PRECISE: call @_FortranAzpowk + ! FAST: complex.pow %{{.*}}, %{{.*}} : complex end subroutine ! CHECK-LABEL: pow_c4_c4 diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index b3d6d59e25bd0..558fcdf782800 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -56,10 +56,43 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern { private: std::string funcName; }; + +// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z)) +struct PowOpToROCDLLibraryCalls : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + if (auto constOp = op.getRhs().getDefiningOp()) { + ArrayAttr value = constOp.getValue(); + if (value.size() == 2) { + auto real = dyn_cast(value[0]); + auto imag = dyn_cast(value[1]); + if (real && imag && imag.getValue().isZero()) + for (int i = 2; i <= 8; ++i) + if (real.getValue().isExactlyValue(i)) { + Value base = op.getLhs(); + Value result = base; + for (int j = 1; j < i; ++j) + result = rewriter.create(loc, result, base); + rewriter.replaceOp(op, result); + return success(); + } + } + } + Value logBase = rewriter.create(loc, op.getLhs()); + Value mul = rewriter.create(loc, op.getRhs(), logBase); + Value exp = rewriter.create(loc, mul); + rewriter.replaceOp(op, exp); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); patterns.add>( patterns.getContext(), "__ocml_cabs_f32"); patterns.add>( @@ -110,9 +143,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect(); + target.addLegalOp(); target.addIllegalOp(); + complex::LogOp, complex::PowOp, complex::SinOp, + complex::SqrtOp, complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index 82936d89e8ac1..ef6ae74a45c1c 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -57,6 +57,33 @@ func.func @log_caller(%f: complex, %d: complex) -> (complex, comp return %lf, %ld : complex, complex } +//CHECK-LABEL: @pow_caller +//CHECK: (%[[Z:.*]]: complex, %[[W:.*]]: complex) +func.func @pow_caller(%z: complex, %w: complex) -> complex { + // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) + // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]] + // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) + // CHECK: return %[[EXP]] + %r = complex.pow %z, %w : complex + return %r : complex +} + +// CHECK-LABEL: @pow_int_caller +func.func @pow_int_caller(%f : complex, %d : complex) + ->(complex, complex) { + // CHECK-NOT: call @__ocml + // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex + %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex + %p2 = complex.pow %f, %c2 : complex + // CHECK-NOT: call @__ocml + // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex + // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex + %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex + %p3 = complex.pow %d, %c3 : complex + // CHECK: return %[[M2]], %[[M3B]] + return %p2, %p3 : complex, complex +} + //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex, %d: complex) -> (complex, complex) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) From 9f0145439c29f91a0b71f39675cd26603a585e20 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 12 Aug 2025 20:07:08 +0100 Subject: [PATCH 2/6] Change constant pow special case handling from complex::ConstantOp to complex::CreateOp. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 45 ++++++++-------- .../ComplexToROCDLLibraryCalls.cpp | 54 +++++++++++++------ .../complex-to-rocdl-library-calls.mlir | 12 +++-- 3 files changed, 70 insertions(+), 41 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 74279a7d72078..89866bb143fba 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1276,10 +1276,10 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } -mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc, - const MathOperation &mathOp, - mlir::FunctionType mathLibFuncType, - llvm::ArrayRef args) { +mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, + const MathOperation &mathOp, + mlir::FunctionType mathLibFuncType, + llvm::ArrayRef args) { bool canUseApprox = mlir::arith::bitEnumContainsAny( builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); @@ -1648,18 +1648,18 @@ static constexpr MathOperation mathOperations[] = { genMathOp}, {"pow", RTNAME_STRING(cpowi), genFuncType, Ty::Complex<4>, Ty::Integer<4>>, - genComplexPowI}, + genComplexPow}, {"pow", RTNAME_STRING(zpowi), genFuncType, Ty::Complex<8>, Ty::Integer<4>>, - genComplexPowI}, + genComplexPow}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, genLibF128Call}, {"pow", RTNAME_STRING(cpowk), genFuncType, Ty::Complex<4>, Ty::Integer<8>>, - genComplexPowI}, + genComplexPow}, {"pow", RTNAME_STRING(zpowk), genFuncType, Ty::Complex<8>, Ty::Integer<8>>, - genComplexPowI}, + genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, genLibF128Call}, {"remainder", "remainderf", @@ -4058,21 +4058,20 @@ void IntrinsicLibrary::genExecuteCommandLine( mlir::Value waitAddr = fir::getBase(wait); mlir::Value waitIsPresentAtRuntime = builder.genIsNotNullAddr(loc, waitAddr); - waitBool = builder - .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, - /*withElseRegion=*/true) - .genThen([&]() { - auto waitLoad = - fir::LoadOp::create(builder, loc, waitAddr); - mlir::Value cast = - builder.createConvert(loc, i1Ty, waitLoad); - fir::ResultOp::create(builder, loc, cast); - }) - .genElse([&]() { - mlir::Value trueVal = builder.createBool(loc, true); - fir::ResultOp::create(builder, loc, trueVal); - }) - .getResults()[0]; + waitBool = + builder + .genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime, + /*withElseRegion=*/true) + .genThen([&]() { + auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr); + mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad); + fir::ResultOp::create(builder, loc, cast); + }) + .genElse([&]() { + mlir::Value trueVal = builder.createBool(loc, true); + fir::ResultOp::create(builder, loc, trueVal); + }) + .getResults()[0]; } mlir::Value exitstatBox = diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 558fcdf782800..3bb40dd705cc2 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" @@ -58,29 +59,52 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern { }; // Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z)) +// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8 struct PowOpToROCDLLibraryCalls : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(complex::PowOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); - if (auto constOp = op.getRhs().getDefiningOp()) { - ArrayAttr value = constOp.getValue(); - if (value.size() == 2) { - auto real = dyn_cast(value[0]); - auto imag = dyn_cast(value[1]); - if (real && imag && imag.getValue().isZero()) - for (int i = 2; i <= 8; ++i) - if (real.getValue().isExactlyValue(i)) { - Value base = op.getLhs(); - Value result = base; - for (int j = 1; j < i; ++j) - result = rewriter.create(loc, result, base); - rewriter.replaceOp(op, result); - return success(); - } + + auto peelConst = [&](Value val) -> std::optional { + while (val) { + Operation *defOp = val.getDefiningOp(); + if (!defOp) + return std::nullopt; + + if (auto constVal = dyn_cast(defOp)) + return dyn_cast(constVal.getValue()); + + if (defOp->getName().getStringRef() == "fir.convert" && + defOp->getNumOperands() == 1) { + val = defOp->getOperand(0); + continue; + } + return std::nullopt; + } + return std::nullopt; + }; + + if (auto createOp = op.getRhs().getDefiningOp()) { + auto image = peelConst(createOp.getImaginary()); + auto real = peelConst(createOp.getReal()); + if (image && real) { + auto imagFloat = dyn_cast(*image); + if (imagFloat && imagFloat.getValue().isZero()) { + auto realInt = dyn_cast(*real); + if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) { + Value base = op.getLhs(); + Value result = base; + for (int i = 1; i < realInt.getInt(); ++i) + result = rewriter.create(loc, result, base); + rewriter.replaceOp(op, result); + return success(); + } + } } } + Value logBase = rewriter.create(loc, op.getLhs()); Value mul = rewriter.create(loc, op.getRhs(), logBase); Value exp = rewriter.create(loc, mul); diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index ef6ae74a45c1c..ba0dd92e20747 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s +// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s // CHECK-DAG: @__ocml_cabs_f32(complex) -> f32 // CHECK-DAG: @__ocml_cabs_f64(complex) -> f64 @@ -73,12 +73,18 @@ func.func @pow_int_caller(%f : complex, %d : complex) ->(complex, complex) { // CHECK-NOT: call @__ocml // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex - %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex + %c2_i32 = arith.constant 2 : i32 + %c2r = "fir.convert"(%c2_i32) : (i32) -> f32 + %c2i = arith.constant 0.0 : f32 + %c2 = complex.create %c2r, %c2i : complex %p2 = complex.pow %f, %c2 : complex // CHECK-NOT: call @__ocml // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex - %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex + %c3_i32 = arith.constant 3 : i32 + %c3r = "fir.convert"(%c3_i32) : (i32) -> f64 + %c3i = arith.constant 0.0 : f64 + %c3 = complex.create %c3r, %c3i : complex %p3 = complex.pow %d, %c3 : complex // CHECK: return %[[M2]], %[[M3B]] return %p2, %p3 : complex, complex From dce85c3dbb3d503c945b6088260bb572d05f1933 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Thu, 14 Aug 2025 13:12:30 +0100 Subject: [PATCH 3/6] Move cpow constant optimisation to Fortran lowering. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 8 ++++ flang/test/Lower/amdgcn-complex.f90 | 17 ++++++-- flang/test/Lower/power-operator.f90 | 8 ++++ .../ComplexToROCDLLibraryCalls.cpp | 41 ------------------- .../complex-to-rocdl-library-calls.mlir | 22 ---------- 5 files changed, 29 insertions(+), 67 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 89866bb143fba..a424007eee799 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1280,6 +1280,14 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, const MathOperation &mathOp, mlir::FunctionType mathLibFuncType, llvm::ArrayRef args) { + if (auto expInt = fir::getIntIfConstant(args[1])) + if (*expInt >= 2 && *expInt <= 8) { + mlir::Value result = args[0]; + for (int i = 1; i < *expInt; ++i) + result = builder.create(loc, result, args[0]); + return builder.createConvert(loc, mathLibFuncType.getResult(0), result); + } + bool canUseApprox = mlir::arith::bitEnumContainsAny( builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index 3d52355d3d50a..dab8cb4034883 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -18,10 +18,19 @@ subroutine cexpf_test(a, b) b = exp(a) end subroutine -! CHECK-LABEL: func @_QPpow_test( -! CHECK: complex.pow +! CHECK-LABEL: func @_QPpow_test1( +! CHECK: complex.mul +! CHECK-NOT: complex.pow ! CHECK-NOT: fir.call @_FortranAcpowi -subroutine pow_test(a, b) +subroutine pow_test1(a, b) complex :: a, b a = b**2 -end subroutine pow_test +end subroutine pow_test1 + +! CHECK-LABEL: func @_QPpow_test2( +! CHECK: complex.pow +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine pow_test2(a, b, c) + complex :: a, b, c + a = b**c +end subroutine pow_test2 diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 2a0a09e090dde..a8943a3aa8c0b 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -143,3 +143,11 @@ subroutine pow_c8_c8(x, y, z) ! PRECISE: call @cpow end subroutine +! CHECK-LABEL: pow_const +subroutine pow_const(a, b) + complex :: a, b + ! CHECK-NOT: complex.pow + ! CHECK-NOT: @_FortranAcpowi + ! CHECK-COUNT-3: complex.mul + a = b**4 +end subroutine diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 3bb40dd705cc2..cc0e93248a114 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" @@ -59,52 +58,12 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern { }; // Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z)) -// Rewrite complex.pow(z, i) -> z * z ... * z for 2 >= i <=8 struct PowOpToROCDLLibraryCalls : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(complex::PowOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); - - auto peelConst = [&](Value val) -> std::optional { - while (val) { - Operation *defOp = val.getDefiningOp(); - if (!defOp) - return std::nullopt; - - if (auto constVal = dyn_cast(defOp)) - return dyn_cast(constVal.getValue()); - - if (defOp->getName().getStringRef() == "fir.convert" && - defOp->getNumOperands() == 1) { - val = defOp->getOperand(0); - continue; - } - return std::nullopt; - } - return std::nullopt; - }; - - if (auto createOp = op.getRhs().getDefiningOp()) { - auto image = peelConst(createOp.getImaginary()); - auto real = peelConst(createOp.getReal()); - if (image && real) { - auto imagFloat = dyn_cast(*image); - if (imagFloat && imagFloat.getValue().isZero()) { - auto realInt = dyn_cast(*real); - if (realInt && realInt.getInt() >= 2 && realInt.getInt() <= 8) { - Value base = op.getLhs(); - Value result = base; - for (int i = 1; i < realInt.getInt(); ++i) - result = rewriter.create(loc, result, base); - rewriter.replaceOp(op, result); - return success(); - } - } - } - } - Value logBase = rewriter.create(loc, op.getLhs()); Value mul = rewriter.create(loc, op.getRhs(), logBase); Value exp = rewriter.create(loc, mul); diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index ba0dd92e20747..080ba4f0ff67b 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -68,28 +68,6 @@ func.func @pow_caller(%z: complex, %w: complex) -> complex { return %r : complex } -// CHECK-LABEL: @pow_int_caller -func.func @pow_int_caller(%f : complex, %d : complex) - ->(complex, complex) { - // CHECK-NOT: call @__ocml - // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex - %c2_i32 = arith.constant 2 : i32 - %c2r = "fir.convert"(%c2_i32) : (i32) -> f32 - %c2i = arith.constant 0.0 : f32 - %c2 = complex.create %c2r, %c2i : complex - %p2 = complex.pow %f, %c2 : complex - // CHECK-NOT: call @__ocml - // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex - // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex - %c3_i32 = arith.constant 3 : i32 - %c3r = "fir.convert"(%c3_i32) : (i32) -> f64 - %c3i = arith.constant 0.0 : f64 - %c3 = complex.create %c3r, %c3i : complex - %p3 = complex.pow %d, %c3 : complex - // CHECK: return %[[M2]], %[[M3B]] - return %p2, %p3 : complex, complex -} - //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex, %d: complex) -> (complex, complex) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) From 0cebae8de1fb432150811f319cbcd12522c2321e Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 19 Aug 2025 16:07:57 +0100 Subject: [PATCH 4/6] Remove constant optimisation. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 8 -------- flang/test/Lower/amdgcn-complex.f90 | 15 +++------------ flang/test/Lower/power-operator.f90 | 9 --------- 3 files changed, 3 insertions(+), 29 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index a424007eee799..89866bb143fba 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1280,14 +1280,6 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, const MathOperation &mathOp, mlir::FunctionType mathLibFuncType, llvm::ArrayRef args) { - if (auto expInt = fir::getIntIfConstant(args[1])) - if (*expInt >= 2 && *expInt <= 8) { - mlir::Value result = args[0]; - for (int i = 1; i < *expInt; ++i) - result = builder.create(loc, result, args[0]); - return builder.createConvert(loc, mathLibFuncType.getResult(0), result); - } - bool canUseApprox = mlir::arith::bitEnumContainsAny( builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index dab8cb4034883..4ee5de4d2842e 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -18,19 +18,10 @@ subroutine cexpf_test(a, b) b = exp(a) end subroutine -! CHECK-LABEL: func @_QPpow_test1( -! CHECK: complex.mul -! CHECK-NOT: complex.pow -! CHECK-NOT: fir.call @_FortranAcpowi -subroutine pow_test1(a, b) - complex :: a, b - a = b**2 -end subroutine pow_test1 - -! CHECK-LABEL: func @_QPpow_test2( +! CHECK-LABEL: func @_QPpow_test( ! CHECK: complex.pow ! CHECK-NOT: fir.call @_FortranAcpowi -subroutine pow_test2(a, b, c) +subroutine pow_test(a, b, c) complex :: a, b, c a = b**c -end subroutine pow_test2 +end subroutine pow_test diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index a8943a3aa8c0b..ebce4f52d449d 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -142,12 +142,3 @@ subroutine pow_c8_c8(x, y, z) ! FAST: complex.pow %{{.*}}, %{{.*}} : complex ! PRECISE: call @cpow end subroutine - -! CHECK-LABEL: pow_const -subroutine pow_const(a, b) - complex :: a, b - ! CHECK-NOT: complex.pow - ! CHECK-NOT: @_FortranAcpowi - ! CHECK-COUNT-3: complex.mul - a = b**4 -end subroutine From 1325a564dcdd82f23dc304ba0ba031145a6ce02e Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Wed, 20 Aug 2025 15:09:53 +0100 Subject: [PATCH 5/6] Address nit comments. --- .../ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index cc0e93248a114..e03979f731b5e 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -63,7 +64,7 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern { LogicalResult matchAndRewrite(complex::PowOp op, PatternRewriter &rewriter) const final { - auto loc = op.getLoc(); + Location loc = op.getLoc(); Value logBase = rewriter.create(loc, op.getLhs()); Value mul = rewriter.create(loc, op.getRhs(), logBase); Value exp = rewriter.create(loc, mul); From a677ee9890bc6ac021b8a3e62d992d545af1606e Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Wed, 20 Aug 2025 17:27:57 +0100 Subject: [PATCH 6/6] Address reviewer comments. --- flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 4 +--- flang/test/Lower/power-operator.f90 | 13 +++++-------- .../ComplexToROCDLLibraryCalls.cpp | 1 - 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 89866bb143fba..2375764222efa 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1280,10 +1280,8 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, const MathOperation &mathOp, mlir::FunctionType mathLibFuncType, llvm::ArrayRef args) { - bool canUseApprox = mlir::arith::bitEnumContainsAny( - builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn); bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); - if (!forceMlirComplex && !canUseApprox && !isAMDGPU) + if (!isAMDGPU) return genLibCall(builder, loc, mathOp, mathLibFuncType, args); auto complexTy = mlir::cast(mathLibFuncType.getInput(0)); diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index ebce4f52d449d..7436e031d20cb 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -96,8 +96,7 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! PRECISE: call @_FortranAcpowi - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex + ! CHECK: call @_FortranAcpowi end subroutine ! CHECK-LABEL: pow_c4_i8 @@ -105,8 +104,7 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! PRECISE: call @_FortranAcpowk - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex + ! CHECK: call @_FortranAcpowk end subroutine ! CHECK-LABEL: pow_c8_i4 @@ -114,8 +112,7 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! PRECISE: call @_FortranAzpowi - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex + ! CHECK: call @_FortranAzpowi end subroutine ! CHECK-LABEL: pow_c8_i8 @@ -123,8 +120,7 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! PRECISE: call @_FortranAzpowk - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex + ! CHECK: call @_FortranAzpowk end subroutine ! CHECK-LABEL: pow_c4_c4 @@ -142,3 +138,4 @@ subroutine pow_c8_c8(x, y, z) ! FAST: complex.pow %{{.*}}, %{{.*}} : complex ! PRECISE: call @cpow end subroutine + diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index e03979f731b5e..7a3a7fdb73e5e 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -9,7 +9,6 @@ #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h"