Merged
Changes from 5 commits
63 changes: 44 additions & 19 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}

mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
bool canUseApprox = mlir::arith::bitEnumContainsAny(
builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);

auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
auto realTy = complexTy.getElementType();
mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
mlir::Value complexExp =
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
mlir::Value result =
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
Contributor: Please make sure that complex.pow is only generated when isAMDGPU is true; otherwise, I would expect performance regressions in afn compilations.

Member Author: Done.

result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
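For reference, a minimal sketch of the IR this helper produces for a complex<f32> base raised to an i32 exponent. The function name and SSA value names are illustrative, and FIR's fir.convert is written here as arith.sitofp so the snippet stands alone under upstream mlir-opt:

    func.func @pow_zi(%z: complex<f32>, %n: i32) -> complex<f32> {
      // Widen the integer exponent to the complex element type.
      %re = arith.sitofp %n : i32 to f32
      %zero = arith.constant 0.0 : f32
      // Build a complex exponent (re, 0) and raise the base to it.
      %w = complex.create %re, %zero : complex<f32>
      %r = complex.pow %z, %w : complex<f32>
      return %r : complex<f32>
    }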

/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
Contributor: Can you please clarify the benefits of expanding these functions in MLIR vs implementing the same logic in the Fortran runtime compiled for the AMD GPU device? I do not have any concerns, I am just curious.

Member Author: I guess it's more modular in the sense that any frontend can just lower to the complex dialect and have the conversion pass take care of the rest, rather than have every frontend lower specifically for amdgcn. But Flang is the only concern at the moment, so I'm happy to move it.

Contributor: Sorry, the "runtime" approach that I meant was that we generate the cpowi etc. calls here, and then the runtime implementation for the AMD GPU device uses the __ocml_* intrinsics. Would that be a viable solution? I guess the benefit of having the complex operations is some special-case handling, like the constant exponent optimizations.

Member Author: Sorry, but I didn't understand. This pass only runs on the device pass for amdgpu and converts the complex ops to the relevant ocml library calls. Are you suggesting we delay this conversion to something like mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp?

Contributor: I meant that we could just have AMD GPU specific versions of _FortranAcpowi and the other functions in flang-rt/lib/runtime.

Member Author: I see, I'll look into whether that's a possibility. But I guess this is a workaround at the moment for lib functions that are not available on the GPU. mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp also does something similar by converting things to ROCDL calls.

genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
genComplexPow},
Contributor: I think it would be better to handle these cases the same way as the other intrinsics, where we can generate either a lib call or an MLIR operation, e.g.:

    {"pow", RTNAME_STRING(FPow4i),
     genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
     genMathOp<mlir::math::FPowIOp>},

Then you may have a ROCDL-specific pass that converts the complex operations into AMD GPU code, and a Flang pipeline pass that converts the complex operations into the runtime calls. You may also have a pass that does the canonicalization/folding for the constant exponent cases.

Member Author: Sorry, but I don't understand the change. Currently, @_FortranAcpowi calls are generated for all cases, and this PR adds lowering to the complex.pow op for the amdgpu device pass. The complex.pow gets later converted to ocml calls. Can you please clarify what you are suggesting instead?

Thanks.

Contributor: Please look at how this case works:

    {"pow", RTNAME_STRING(FPow4i),
     genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
     genMathOp<mlir::math::FPowIOp>},

Depending on the mathRuntimeVersion, Flang either generates a call to _FortranAFPow4i or an mlir::math::FPowIOp operation. You can do the same for _FortranAcpowi vs complex.pow, and then handle complex.pow any way you wish later in the pipeline. So for AMD GPU you may convert it to the ocml calls, and otherwise you may convert it to _FortranAcpowi late in the Flang pass pipeline. This way, we get all the benefits of not having a call with side effects at the MLIR level, and we can apply folding/canonicalization to complex.pow.

Member Author: I've updated the code to reflect this change. Let me know if it's what you wanted or whether you would like to see any further changes.

Thanks.

Contributor: Thanks for the update! It seems to be the right direction to me, though there are a couple of missing things:

  1. I think we need to make sure that we still call _FortranAcpowi and the other Fortran runtime functions for Flang, so we need a pass that converts the complex pow operations back to Fortran runtime calls (unless the ROCDL conversion converts them to AMD GPU specific code).
  2. I would suggest introducing a powi operation in the Complex dialect, so that we know with 100% certainty that the exponent argument is an integer value. If there is a way to guarantee that we always recognize complex.pow's integer exponent argument whenever Flang created such an operation, then powi is redundant. So it depends on how reliable the recognition is.

Member Author: I think the optimisation has made the PR deviate too far from its original scope. The original motivation of this PR is only to add support for cpow lowering on AMDGPU.

I've reverted the PR to an older revision which was already accepted by @krzysz00, and dropped the optimisation entirely.

I'll start a separate PR soon for the Flang lowering changes along with the optimisation.

Please let me know if you would like to see any changes to this PR before I merge it.

Thanks.
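For reference, a minimal standalone sketch of the device-side expansion this PR introduces: after the -convert-complex-to-rocdl-library-calls pass, complex.pow is rewritten using the identity z**w = exp(w * log(z)) into ocml calls plus a complex.mul. The function and value names below are illustrative; only the __ocml_* symbols and the op sequence come from this PR's pass and test:

    func.func private @__ocml_clog_f32(complex<f32>) -> complex<f32>
    func.func private @__ocml_cexp_f32(complex<f32>) -> complex<f32>

    func.func @device_pow(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
      // complex.pow %z, %w expands to exp(w * log(z)) via the ocml wrappers.
      %log = call @__ocml_clog_f32(%z) : (complex<f32>) -> complex<f32>
      %mul = complex.mul %w, %log : complex<f32>
      %r = call @__ocml_cexp_f32(%mul) : (complex<f32>) -> complex<f32>
      return %r : complex<f32>
    }

On the host path, the pow entries below still resolve to the _FortranAcpowi/_FortranAzpowi library calls unless forceMlirComplex, the afn fast-math flag, or an AMDGCN target lets genComplexPow emit complex.pow.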

{"pow", RTNAME_STRING(zpowi),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(zpowk),
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
{"remainder", "remainderf",
@@ -4032,21 +4058,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
mlir::Value waitAddr = fir::getBase(wait);
mlir::Value waitIsPresentAtRuntime =
builder.genIsNotNullAddr(loc, waitAddr);
waitBool = builder
.genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
/*withElseRegion=*/true)
.genThen([&]() {
auto waitLoad =
fir::LoadOp::create(builder, loc, waitAddr);
mlir::Value cast =
builder.createConvert(loc, i1Ty, waitLoad);
fir::ResultOp::create(builder, loc, cast);
})
.genElse([&]() {
mlir::Value trueVal = builder.createBool(loc, true);
fir::ResultOp::create(builder, loc, trueVal);
})
.getResults()[0];
waitBool =
builder
.genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
/*withElseRegion=*/true)
.genThen([&]() {
auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
fir::ResultOp::create(builder, loc, cast);
})
.genElse([&]() {
mlir::Value trueVal = builder.createBool(loc, true);
fir::ResultOp::create(builder, loc, trueVal);
})
.getResults()[0];
}

mlir::Value exitstatBox =
22 changes: 14 additions & 8 deletions flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
! REQUIRES: amdgpu-registered-target
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s

! CHECK-LABEL: func @_QPcabsf_test(
! CHECK: complex.abs
! CHECK-NOT: fir.call @cabsf
subroutine cabsf_test(a, b)
complex :: a
real :: b
b = abs(a)
end subroutine

! CHECK-LABEL: func @_QPcabsf_test(
! CHECK: complex.abs
! CHECK-NOT: fir.call @cabsf

! CHECK-LABEL: func @_QPcexpf_test(
! CHECK: complex.exp
! CHECK-NOT: fir.call @cexpf
subroutine cexpf_test(a, b)
complex :: a, b
b = exp(a)
end subroutine

! CHECK-LABEL: func @_QPcexpf_test(
! CHECK: complex.exp
! CHECK-NOT: fir.call @cexpf
! CHECK-LABEL: func @_QPpow_test(
! CHECK: complex.pow
! CHECK-NOT: fir.call @_FortranAcpowi
subroutine pow_test(a, b, c)
complex :: a, b, c
a = b**c
end subroutine pow_test
13 changes: 8 additions & 5 deletions flang/test/Lower/power-operator.f90
@@ -96,31 +96,35 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
! CHECK: call @_FortranAcpowi
! PRECISE: call @_FortranAcpowi
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine

! CHECK-LABEL: pow_c4_i8
subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
! CHECK: call @_FortranAcpowk
! PRECISE: call @_FortranAcpowk
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine

! CHECK-LABEL: pow_c8_i4
subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
! CHECK: call @_FortranAzpowi
! PRECISE: call @_FortranAzpowi
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine

! CHECK-LABEL: pow_c8_i8
subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
! CHECK: call @_FortranAzpowk
! PRECISE: call @_FortranAzpowk
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine

! CHECK-LABEL: pow_c4_c4
@@ -138,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
! PRECISE: call @cpow
end subroutine

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

@@ -56,10 +57,26 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};

// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
using OpRewritePattern<complex::PowOp>::OpRewritePattern;

LogicalResult matchAndRewrite(complex::PowOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
Value exp = rewriter.create<complex::ExpOp>(loc, mul);
rewriter.replaceOp(op, exp);
return success();
}
};
} // namespace

void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +127,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {

ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
target.addLegalOp<complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
complex::LogOp, complex::SinOp, complex::SqrtOp,
complex::TanOp, complex::TanhOp>();
complex::LogOp, complex::PowOp, complex::SinOp,
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s

// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
@@ -57,6 +57,17 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
return %lf, %ld : complex<f32>, complex<f64>
}

//CHECK-LABEL: @pow_caller
//CHECK: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
// CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
// CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
// CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
// CHECK: return %[[EXP]]
%r = complex.pow %z, %w : complex<f32>
return %r : complex<f32>
}

//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})