Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}

mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
mlir::Value exp = args[1];
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
exp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
}
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}

/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
Expand Down Expand Up @@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
genComplexPow},
genMathOp<mlir::complex::PowOp>},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
Expand All @@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = {
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowi),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genComplexPow},
genMathOp<mlir::complex::PowiOp>},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
Expand Down
66 changes: 28 additions & 38 deletions flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
return func;
}

static bool isZero(Value v) {
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
return attr.getValue().isZero();
return false;
}

void ConvertComplexPowPass::runOnOperation() {
ModuleOp mod = getOperation();
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));

mod.walk([&](complex::PowOp op) {
builder.setInsertionPoint(op);
Location loc = op.getLoc();
auto complexTy = cast<ComplexType>(op.getType());
auto elemTy = complexTy.getElementType();

Value base = op.getLhs();
Value rhs = op.getRhs();

Value intExp;
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
if (isZero(create.getImaginary())) {
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
intExp = conv.getValue();
}
}
}

func::FuncOp callee;
SmallVector<Value> args;
if (intExp) {
mod.walk([&](Operation *op) {
if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
builder.setInsertionPoint(powIop);
Location loc = powIop.getLoc();
auto complexTy = cast<ComplexType>(powIop.getType());
auto elemTy = complexTy.getElementType();
Value base = powIop.getLhs();
Value intExp = powIop.getRhs();
func::FuncOp callee;
unsigned realBits = cast<FloatType>(elemTy).getWidth();
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
auto funcTy = builder.getFunctionType(
Expand All @@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() {
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
else
return;
args = {base, intExp};
} else {
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
if (auto fmf = powIop.getFastmathAttr())
call.setFastmathAttr(fmf);
powIop.replaceAllUsesWith(call.getResult(0));
powIop.erase();
}

if (auto powOp = dyn_cast<complex::PowOp>(op)) {
builder.setInsertionPoint(powOp);
Location loc = powOp.getLoc();
auto complexTy = cast<ComplexType>(powOp.getType());
auto elemTy = complexTy.getElementType();
unsigned realBits = cast<FloatType>(elemTy).getWidth();
func::FuncOp callee;
auto funcTy =
builder.getFunctionType({complexTy, complexTy}, {complexTy});
if (realBits == 32)
Expand All @@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() {
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
else
return;
args = {base, rhs};
auto call = fir::CallOp::create(builder, loc, callee,
{powOp.getLhs(), powOp.getRhs()});
if (auto fmf = powOp.getFastmathAttr())
call.setFastmathAttr(fmf);
powOp.replaceAllUsesWith(call.getResult(0));
powOp.erase();
}

auto call = fir::CallOp::create(builder, loc, callee, args);
if (auto fmf = op.getFastmathAttr())
call.setFastmathAttr(fmf);
op.replaceAllUsesWith(call.getResult(0));
op.erase();
});
}
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/binary-ops.f90
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
! CHECK: %[[VAL_8:.*]] = complex.pow
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32

subroutine extremum(c, n, l)
integer(8), intent(in) :: l
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/Intrinsics/pow_complex16i.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s

! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/Intrinsics/pow_complex16k.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s

! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
Expand Down
9 changes: 9 additions & 0 deletions flang/test/Lower/amdgcn-complex.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test

! CHECK-LABEL: func @_QPpowi_test(
! CHECK: complex.powi
! CHECK-NOT: fir.call @_FortranAcpowi
subroutine powi_test(a, b, c)
complex :: a, b
integer :: i
b = a ** i
end subroutine powi_test
9 changes: 4 additions & 5 deletions flang/test/Lower/power-operator.f90
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
! PRECISE: fir.call @_FortranAcpowi
end subroutine

Expand All @@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
! PRECISE: fir.call @_FortranAcpowk
end subroutine

Expand All @@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
! PRECISE: fir.call @_FortranAzpowi
end subroutine

Expand All @@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
! CHECK: complex.pow
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
! PRECISE: fir.call @_FortranAzpowk
end subroutine

Expand All @@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: fir.call @cpow
end subroutine

60 changes: 29 additions & 31 deletions flang/test/Transforms/convert-complex-pow.fir
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,38 @@

module {
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i32) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
return %0 : complex<f32>
}

func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
%0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
return %0 : complex<f32>
}

func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
%c0 = arith.constant 0.0 : f32
%0 = fir.convert %arg1 : (i64) -> f32
%1 = complex.create %0, %c0 : complex<f32>
%2 = complex.pow %arg0, %1 : complex<f32>
return %2 : complex<f32>
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
return %0 : complex<f32>
}

func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i32) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
return %0 : complex<f64>
}

func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
%c0 = arith.constant 0.0 : f64
%0 = fir.convert %arg1 : (i64) -> f64
%1 = complex.create %0, %c0 : complex<f64>
%2 = complex.pow %arg0, %1 : complex<f64>
return %2 : complex<f64>
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
return %0 : complex<f64>
}

func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i32) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
return %0 : complex<f128>
}

func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
%c0 = arith.constant 0.0 : f128
%0 = fir.convert %arg1 : (i64) -> f128
%1 = complex.create %0, %c0 : complex<f128>
%2 = complex.pow %arg0, %1 : complex<f128>
return %2 : complex<f128>
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
return %0 : complex<f128>
}

func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
Expand Down Expand Up @@ -74,26 +61,37 @@ module {
// CHECK-LABEL: func.func @pow_c4_i4(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c4_i4_fast(
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c4_i8(
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c8_i4(
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c8_i8(
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c16_i4(
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c16_i8(
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.powi

// CHECK-LABEL: func.func @pow_c4_fast(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
Expand All @@ -108,4 +106,4 @@ module {
// CHECK-LABEL: func.func @pow_c16_complex(
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
// CHECK-NOT: complex.pow
// CHECK-NOT: complex.pow
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> {
}];
}

//===----------------------------------------------------------------------===//
// PowiOp
//===----------------------------------------------------------------------===//

def PowiOp : Complex_Op<"powi",
[Pure, Elementwise, SameOperandsAndResultShape,
AllTypesMatch<["lhs", "result"]>,
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "complex number raised to signed integer power";
let description = [{
The `powi` operation takes a `base` operand of complex type and a `power`
operand of signed integer type and returns one result of the same type
as `base`. The result is `base` raised to the power of `power`.

Example:

```mlir
%a = complex.powi %b, %c : complex<f32>, i32
```
}];

let arguments = (ins Complex<AnyFloat>:$lhs,
AnySignlessInteger:$rhs,
OptionalAttr<Arith_FastMathAttr>:$fastmath);
let results = (outs Complex<AnyFloat>:$result);

let assemblyFormat =
"$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
}

//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
Expand Down
Loading