-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Flang] Add new ConvertComplexPow pass for Flang #158642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,124 @@ | ||||||||||||||||
//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===// | ||||||||||||||||
// | ||||||||||||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||||||||||||||
// See https://llvm.org/LICENSE.txt for license information. | ||||||||||||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||||||||||||||
// | ||||||||||||||||
//===----------------------------------------------------------------------===// | ||||||||||||||||
|
||||||||||||||||
#include "flang/Common/static-multimap-view.h" | ||||||||||||||||
#include "flang/Optimizer/Builder/FIRBuilder.h" | ||||||||||||||||
#include "flang/Optimizer/Dialect/FIRDialect.h" | ||||||||||||||||
#include "flang/Optimizer/Transforms/Passes.h" | ||||||||||||||||
#include "flang/Runtime/entry-names.h" | ||||||||||||||||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||||||||||||||||
#include "mlir/Dialect/Complex/IR/Complex.h" | ||||||||||||||||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||||||||||||||||
#include "mlir/Pass/Pass.h" | ||||||||||||||||
|
||||||||||||||||
namespace fir { | ||||||||||||||||
#define GEN_PASS_DEF_CONVERTCOMPLEXPOW | ||||||||||||||||
#include "flang/Optimizer/Transforms/Passes.h.inc" | ||||||||||||||||
} // namespace fir | ||||||||||||||||
|
||||||||||||||||
using namespace mlir; | ||||||||||||||||
|
||||||||||||||||
namespace { | ||||||||||||||||
class ConvertComplexPowPass | ||||||||||||||||
: public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> { | ||||||||||||||||
public: | ||||||||||||||||
void getDependentDialects(DialectRegistry ®istry) const override { | ||||||||||||||||
registry.insert<fir::FIROpsDialect, complex::ComplexDialect, | ||||||||||||||||
arith::ArithDialect, func::FuncDialect>(); | ||||||||||||||||
} | ||||||||||||||||
TIFitis marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||
void runOnOperation() override; | ||||||||||||||||
}; | ||||||||||||||||
} // namespace | ||||||||||||||||
|
||||||||||||||||
// Helper to declare or get a math library function. | ||||||||||||||||
static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, | ||||||||||||||||
StringRef name, FunctionType type) { | ||||||||||||||||
if (auto func = builder.getNamedFunction(name)) | ||||||||||||||||
return func; | ||||||||||||||||
auto func = builder.createFunction(loc, name, type); | ||||||||||||||||
func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name)); | ||||||||||||||||
func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(), | ||||||||||||||||
builder.getUnitAttr()); | ||||||||||||||||
return func; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
static bool isZero(Value v) { | ||||||||||||||||
if (auto cst = v.getDefiningOp<arith::ConstantOp>()) | ||||||||||||||||
if (auto attr = dyn_cast<FloatAttr>(cst.getValue())) | ||||||||||||||||
return attr.getValue().isZero(); | ||||||||||||||||
return false; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
void ConvertComplexPowPass::runOnOperation() { | ||||||||||||||||
ModuleOp mod = getOperation(); | ||||||||||||||||
if (fir::getTargetTriple(mod).isAMDGCN()) | ||||||||||||||||
|
||||||||||||||||
return; | ||||||||||||||||
|
||||||||||||||||
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); | ||||||||||||||||
|
||||||||||||||||
mod.walk([&](complex::PowOp op) { | ||||||||||||||||
builder.setInsertionPoint(op); | ||||||||||||||||
Location loc = op.getLoc(); | ||||||||||||||||
auto complexTy = cast<ComplexType>(op.getType()); | ||||||||||||||||
auto elemTy = complexTy.getElementType(); | ||||||||||||||||
|
||||||||||||||||
Value base = op.getLhs(); | ||||||||||||||||
Value rhs = op.getRhs(); | ||||||||||||||||
|
||||||||||||||||
Value intExp; | ||||||||||||||||
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) { | ||||||||||||||||
if (isZero(create.getImaginary())) { | ||||||||||||||||
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) { | ||||||||||||||||
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType())) | ||||||||||||||||
intExp = conv.getValue(); | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
func::FuncOp callee; | ||||||||||||||||
SmallVector<Value> args; | ||||||||||||||||
if (intExp) { | ||||||||||||||||
unsigned realBits = cast<FloatType>(elemTy).getWidth(); | ||||||||||||||||
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth(); | ||||||||||||||||
auto funcTy = builder.getFunctionType( | ||||||||||||||||
{complexTy, builder.getIntegerType(intBits)}, {complexTy}); | ||||||||||||||||
if (realBits == 32 && intBits == 32) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); | ||||||||||||||||
else if (realBits == 32 && intBits == 64) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); | ||||||||||||||||
else if (realBits == 64 && intBits == 32) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); | ||||||||||||||||
else if (realBits == 64 && intBits == 64) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); | ||||||||||||||||
else if (realBits == 128 && intBits == 32) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); | ||||||||||||||||
else if (realBits == 128 && intBits == 64) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); | ||||||||||||||||
else | ||||||||||||||||
return; | ||||||||||||||||
Comment on lines
+99
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Early returns without error handling or logging make debugging difficult. Consider adding a diagnostic message or comment explaining why these combinations are unsupported. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||
args = {base, intExp}; | ||||||||||||||||
} else { | ||||||||||||||||
unsigned realBits = cast<FloatType>(elemTy).getWidth(); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am really worried about dropping the imaginary part for these cases. Imagine, somewhere in Flang we start generating I think we need to keep Ideally, we should have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I have added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for adding |
||||||||||||||||
auto funcTy = | ||||||||||||||||
builder.getFunctionType({complexTy, complexTy}, {complexTy}); | ||||||||||||||||
if (realBits == 32) | ||||||||||||||||
callee = getOrDeclare(builder, loc, "cpowf", funcTy); | ||||||||||||||||
else if (realBits == 64) | ||||||||||||||||
callee = getOrDeclare(builder, loc, "cpow", funcTy); | ||||||||||||||||
else if (realBits == 128) | ||||||||||||||||
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); | ||||||||||||||||
else | ||||||||||||||||
return; | ||||||||||||||||
Comment on lines
+112
to
+113
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Early return without error handling or logging makes debugging difficult. Consider adding a diagnostic message or comment explaining why this bit width is unsupported.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||
args = {base, rhs}; | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
auto call = fir::CallOp::create(builder, loc, callee, args); | ||||||||||||||||
op.replaceAllUsesWith(call.getResult(0)); | ||||||||||||||||
op.erase(); | ||||||||||||||||
}); | ||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
! REQUIRES: flang-supports-f128-math | ||
! RUN: bbc -emit-fir %s -o - | FileCheck %s | ||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s | ||
! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" | ||
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s | ||
|
||
! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128> | ||
! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128> | ||
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128> | ||
complex(16) :: a, b | ||
b = a ** b | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you have
complex.powi
, I think we can just usegenComplexMathOp<mlir::complex::powi>
orgenMathOp<mlir::complex::powi>
here.We can probably get rid of
genComplexPow
and usegenMathOp
instead.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
genComplexMathOp
would lower tolibCall
if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
which means we would restrict lowering to
complex.pow
for some cases where we are currently forcing it. Is that okay?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
check is yet another workaround that has to be removed eventually (not in this PR).I think
genMathOp
should work here just fine or am I missing something?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I glossed over
genMathOp
, I've added this change in #158722 to removedgenComplexPow
.