Skip to content

Commit d69ccde

Browse files
authored
[MLIR] Add cpow support in ComplexToROCDLLibraryCalls (#153183)
This PR adds support for complex power operations (`cpow`) in the `ComplexToROCDLLibraryCalls` conversion pass, specifically targeting AMDGPU architectures. The implementation optimises complex exponentiation by using mathematical identities and special-case handling for small integer powers.
- Force lowering to `complex.pow` operations for the `amdgcn-amd-amdhsa` target instead of using library calls.
- Convert `complex.pow(z, w)` to `complex.exp(w * complex.log(z))` using the mathematical identity.
1 parent 65de318 commit d69ccde

File tree

4 files changed

+87
-30
lines changed

4 files changed

+87
-30
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,26 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
12871287
return result;
12881288
}
12891289

1290+
/// Lower a complex base raised to an integer exponent.
///
/// On every target other than AMDGCN this forwards unchanged to genLibCall.
/// On AMDGCN the exponent (\p args[1]) is converted to the element type of the
/// complex base, widened to a complex value whose imaginary part is zero, and
/// passed together with the base (\p args[0]) to `complex.pow`. The result is
/// converted to the result type of \p mathLibFuncType before being returned.
///
/// NOTE(review): converting the integer exponent to the real element type may
/// round exponents that are not exactly representable in that type (e.g.
/// large 64-bit values with a single-precision base) -- confirm this is
/// acceptable.
mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
                          const MathOperation &mathOp,
                          mlir::FunctionType mathLibFuncType,
                          llvm::ArrayRef<mlir::Value> args) {
  if (!fir::getTargetTriple(builder.getModule()).isAMDGCN())
    return genLibCall(builder, loc, mathOp, mathLibFuncType, args);

  auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
  mlir::Type realTy = complexTy.getElementType();

  // Promote the exponent to a complex value of the form (real(exp), 0).
  mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
  mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
  mlir::Value complexExp =
      mlir::complex::CreateOp::create(builder, loc, complexTy, realExp, zero);

  mlir::Value result =
      mlir::complex::PowOp::create(builder, loc, args[0], complexExp);
  return builder.createConvert(loc, mathLibFuncType.getResult(0), result);
}
1309+
12901310
/// Mapping between mathematical intrinsic operations and MLIR operations
12911311
/// of some appropriate dialect (math, complex, etc.) or libm calls.
12921312
/// TODO: support remaining Fortran math intrinsics.
@@ -1636,15 +1656,19 @@ static constexpr MathOperation mathOperations[] = {
16361656
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
16371657
genMathOp<mlir::math::FPowIOp>},
16381658
{"pow", RTNAME_STRING(cpowi),
1639-
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
1659+
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
1660+
genComplexPow},
16401661
{"pow", RTNAME_STRING(zpowi),
1641-
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
1662+
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
1663+
genComplexPow},
16421664
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
16431665
genLibF128Call},
16441666
{"pow", RTNAME_STRING(cpowk),
1645-
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
1667+
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
1668+
genComplexPow},
16461669
{"pow", RTNAME_STRING(zpowk),
1647-
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
1670+
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
1671+
genComplexPow},
16481672
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
16491673
genLibF128Call},
16501674
{"remainder", "remainderf",
@@ -4044,21 +4068,20 @@ void IntrinsicLibrary::genExecuteCommandLine(
40444068
mlir::Value waitAddr = fir::getBase(wait);
40454069
mlir::Value waitIsPresentAtRuntime =
40464070
builder.genIsNotNullAddr(loc, waitAddr);
4047-
waitBool = builder
4048-
.genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
4049-
/*withElseRegion=*/true)
4050-
.genThen([&]() {
4051-
auto waitLoad =
4052-
fir::LoadOp::create(builder, loc, waitAddr);
4053-
mlir::Value cast =
4054-
builder.createConvert(loc, i1Ty, waitLoad);
4055-
fir::ResultOp::create(builder, loc, cast);
4056-
})
4057-
.genElse([&]() {
4058-
mlir::Value trueVal = builder.createBool(loc, true);
4059-
fir::ResultOp::create(builder, loc, trueVal);
4060-
})
4061-
.getResults()[0];
4071+
waitBool =
4072+
builder
4073+
.genIfOp(loc, {i1Ty}, waitIsPresentAtRuntime,
4074+
/*withElseRegion=*/true)
4075+
.genThen([&]() {
4076+
auto waitLoad = fir::LoadOp::create(builder, loc, waitAddr);
4077+
mlir::Value cast = builder.createConvert(loc, i1Ty, waitLoad);
4078+
fir::ResultOp::create(builder, loc, cast);
4079+
})
4080+
.genElse([&]() {
4081+
mlir::Value trueVal = builder.createBool(loc, true);
4082+
fir::ResultOp::create(builder, loc, trueVal);
4083+
})
4084+
.getResults()[0];
40624085
}
40634086

40644087
mlir::Value exitstatBox =
Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
11
! REQUIRES: amdgpu-registered-target
2-
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
2+
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
33

4+
! CHECK-LABEL: func @_QPcabsf_test(
5+
! CHECK: complex.abs
6+
! CHECK-NOT: fir.call @cabsf
47
subroutine cabsf_test(a, b)
58
complex :: a
69
real :: b
710
b = abs(a)
811
end subroutine
912

10-
! CHECK-LABEL: func @_QPcabsf_test(
11-
! CHECK: complex.abs
12-
! CHECK-NOT: fir.call @cabsf
13-
13+
! CHECK-LABEL: func @_QPcexpf_test(
14+
! CHECK: complex.exp
15+
! CHECK-NOT: fir.call @cexpf
1416
subroutine cexpf_test(a, b)
1517
complex :: a, b
1618
b = exp(a)
1719
end subroutine
1820

19-
! CHECK-LABEL: func @_QPcexpf_test(
20-
! CHECK: complex.exp
21-
! CHECK-NOT: fir.call @cexpf
21+
! CHECK-LABEL: func @_QPpow_test(
22+
! CHECK: complex.pow
23+
! CHECK-NOT: fir.call @_FortranAcpowi
24+
subroutine pow_test(a, b, c)
25+
complex :: a, b, c
26+
a = b**c
27+
end subroutine pow_test

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,26 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
5656
private:
5757
std::string funcName;
5858
};
59+
60+
// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
61+
struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
62+
using OpRewritePattern<complex::PowOp>::OpRewritePattern;
63+
64+
LogicalResult matchAndRewrite(complex::PowOp op,
65+
PatternRewriter &rewriter) const final {
66+
Location loc = op.getLoc();
67+
Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
68+
Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
69+
Value exp = rewriter.create<complex::ExpOp>(loc, mul);
70+
rewriter.replaceOp(op, exp);
71+
return success();
72+
}
73+
};
5974
} // namespace
6075

6176
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
6277
RewritePatternSet &patterns) {
78+
patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
6379
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
6480
patterns.getContext(), "__ocml_cabs_f32");
6581
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
110126

111127
ConversionTarget target(getContext());
112128
target.addLegalDialect<func::FuncDialect>();
129+
target.addLegalOp<complex::MulOp>();
113130
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
114-
complex::LogOp, complex::SinOp, complex::SqrtOp,
115-
complex::TanOp, complex::TanhOp>();
131+
complex::LogOp, complex::PowOp, complex::SinOp,
132+
complex::SqrtOp, complex::TanOp, complex::TanhOp>();
116133
if (failed(applyPartialConversion(op, target, std::move(patterns))))
117134
signalPassFailure();
118135
}

mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
1+
// RUN: mlir-opt %s --allow-unregistered-dialect -convert-complex-to-rocdl-library-calls | FileCheck %s
22

33
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
44
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
@@ -57,6 +57,17 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
5757
return %lf, %ld : complex<f32>, complex<f64>
5858
}
5959

60+
// complex.pow(z, w) is expected to become exp(w * log(z)), with the log and
// exp steps turned into OCML calls and the multiply left as complex.mul.
//CHECK-LABEL: @pow_caller
//CHECK-SAME: (%[[Z:.*]]: complex<f32>, %[[W:.*]]: complex<f32>)
func.func @pow_caller(%z: complex<f32>, %w: complex<f32>) -> complex<f32> {
  // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]])
  // CHECK: %[[MUL:.*]] = complex.mul %[[W]], %[[LOG]]
  // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
  // CHECK: return %[[EXP]]
  %r = complex.pow %z, %w : complex<f32>
  return %r : complex<f32>
}
70+
6071
//CHECK-LABEL: @sin_caller
6172
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
6273
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})

0 commit comments

Comments
 (0)