Skip to content

Commit fc114e4

Browse files
authored
[MLIR] Add ComplexTOROCDLLibraryCalls pass (#144926)
1 parent 1742966 commit fc114e4

File tree

9 files changed

+188
-7
lines changed

9 files changed

+188
-7
lines changed

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ add_flang_library(FIRCodeGen
3434

3535
MLIR_LIBS
3636
MLIRComplexToLLVM
37+
MLIRComplexToROCDLLibraryCalls
3738
MLIRComplexToStandard
3839
MLIRGPUDialect
3940
MLIRMathToFuncs

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
3434
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
3535
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
36+
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
3637
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
3738
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
3839
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4145,22 +4146,24 @@ class FIRToLLVMLowering
41454146
// conversions that affect the ModuleOp, e.g. create new
41464147
// function operations in it. We have to run such conversions
41474148
// as passes here.
4148-
mlir::OpPassManager mathConvertionPM("builtin.module");
4149+
mlir::OpPassManager mathConversionPM("builtin.module");
41494150

41504151
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
41514152
// If compiling for AMD target some math operations must be lowered to AMD
41524153
// GPU library calls, the rest can be converted to LLVM intrinsics, which
41534154
// is handled in the mathToLLVM conversion. The lowering to libm calls is
41544155
// not needed since all math operations are handled this way.
4155-
if (isAMDGCN)
4156-
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
4156+
if (isAMDGCN) {
4157+
mathConversionPM.addPass(mlir::createConvertMathToROCDL());
4158+
mathConversionPM.addPass(mlir::createConvertComplexToROCDLLibraryCalls());
4159+
}
41574160

41584161
// Convert math::FPowI operations to inline implementation
41594162
// only if the exponent's width is greater than 32, otherwise,
41604163
// it will be lowered to LLVM intrinsic operation by a later conversion.
41614164
mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
41624165
mathToFuncsOptions.minWidthOfFPowIExponent = 33;
4163-
mathConvertionPM.addPass(
4166+
mathConversionPM.addPass(
41644167
mlir::createConvertMathToFuncs(mathToFuncsOptions));
41654168

41664169
mlir::ConvertComplexToStandardPassOptions complexToStandardOptions{};
@@ -4173,15 +4176,15 @@ class FIRToLLVMLowering
41734176
complexToStandardOptions.complexRange =
41744177
mlir::complex::ComplexRangeFlags::improved;
41754178
}
4176-
mathConvertionPM.addPass(
4179+
mathConversionPM.addPass(
41774180
mlir::createConvertComplexToStandardPass(complexToStandardOptions));
41784181

41794182
// Convert Math dialect operations into LLVM dialect operations.
41804183
// There is no way to prefer MathToLLVM patterns over MathToLibm
41814184
// patterns (applied below), so we have to run MathToLLVM conversion here.
4182-
mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
4185+
mathConversionPM.addNestedPass<mlir::func::FuncOp>(
41834186
mlir::createConvertMathToLLVMPass());
4184-
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
4187+
if (mlir::failed(runPipeline(mathConversionPM, mod)))
41854188
return signalPassFailure();
41864189

41874190
std::optional<mlir::DataLayout> dl =
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- ComplexToROCDLLibraryCalls.h - convert from Complex to ROCDL calls -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
10+
#define MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_
11+
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "mlir/Pass/Pass.h"
14+
15+
namespace mlir {
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Populate the given list with patterns that convert from Complex to ROCDL
22+
/// calls.
23+
void populateComplexToROCDLLibraryCallsConversionPatterns(
24+
RewritePatternSet &patterns);
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_COMPLEXTOROCDLLIBRARYCALLS_COMPLEXTOROCDLLIBRARYCALLS_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
2424
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
2525
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
26+
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
2627
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
2728
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
2829
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
312312
let dependentDialects = ["func::FuncDialect"];
313313
}
314314

315+
//===----------------------------------------------------------------------===//
316+
// ComplexToROCDLLibraryCalls
317+
//===----------------------------------------------------------------------===//
318+
319+
def ConvertComplexToROCDLLibraryCalls : Pass<"convert-complex-to-rocdl-library-calls", "ModuleOp"> {
320+
let summary = "Convert Complex dialect to ROCDL library calls";
321+
let description = [{
322+
This pass converts supported Complex ops to calls to the AMD device library.
323+
}];
324+
let dependentDialects = ["func::FuncDialect"];
325+
}
326+
315327
//===----------------------------------------------------------------------===//
316328
// ComplexToSPIRV
317329
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
1313
add_subdirectory(BufferizationToMemRef)
1414
add_subdirectory(ComplexCommon)
1515
add_subdirectory(ComplexToLibm)
16+
add_subdirectory(ComplexToROCDLLibraryCalls)
1617
add_subdirectory(ComplexToLLVM)
1718
add_subdirectory(ComplexToSPIRV)
1819
add_subdirectory(ComplexToStandard)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRComplexToROCDLLibraryCalls
2+
ComplexToROCDLLibraryCalls.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDLLibraryCalls
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRComplexDialect
15+
MLIRFuncDialect
16+
MLIRPass
17+
MLIRTransformUtils
18+
)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//=== ComplexToROCDLLibraryCalls.cpp - convert from Complex to ROCDL calls ===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
10+
#include "mlir/Dialect/Complex/IR/Complex.h"
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "mlir/Transforms/DialectConversion.h"
14+
15+
namespace mlir {
16+
#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
17+
#include "mlir/Conversion/Passes.h.inc"
18+
} // namespace mlir
19+
20+
using namespace mlir;
21+
22+
namespace {
23+
24+
template <typename Op, typename FloatTy>
25+
// Pattern to convert Complex ops to ROCDL function calls.
26+
struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
27+
using OpRewritePattern<Op>::OpRewritePattern;
28+
ComplexOpToROCDLLibraryCalls(MLIRContext *context, StringRef funcName,
29+
PatternBenefit benefit = 1)
30+
: OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
31+
32+
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
33+
Operation *symTable = SymbolTable::getNearestSymbolTable(op);
34+
Type resType = op.getType();
35+
if (auto complexType = dyn_cast<ComplexType>(resType))
36+
resType = complexType.getElementType();
37+
if (!isa<FloatTy>(resType))
38+
return failure();
39+
40+
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
41+
SymbolTable::lookupSymbolIn(symTable, funcName));
42+
if (!opFunc) {
43+
OpBuilder::InsertionGuard guard(rewriter);
44+
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
45+
auto funcTy = FunctionType::get(
46+
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
47+
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
48+
funcTy);
49+
opFunc.setPrivate();
50+
}
51+
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
52+
op->getOperands());
53+
return success();
54+
}
55+
56+
private:
57+
std::string funcName;
58+
};
59+
} // namespace
60+
61+
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
62+
RewritePatternSet &patterns) {
63+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
64+
patterns.getContext(), "__ocml_cabs_f32");
65+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
66+
patterns.getContext(), "__ocml_cabs_f64");
67+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
68+
patterns.getContext(), "__ocml_cexp_f32");
69+
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
70+
patterns.getContext(), "__ocml_cexp_f64");
71+
}
72+
73+
namespace {
74+
struct ConvertComplexToROCDLLibraryCallsPass
75+
: public impl::ConvertComplexToROCDLLibraryCallsBase<
76+
ConvertComplexToROCDLLibraryCallsPass> {
77+
void runOnOperation() override;
78+
};
79+
} // namespace
80+
81+
void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
82+
Operation *op = getOperation();
83+
84+
RewritePatternSet patterns(&getContext());
85+
populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
86+
87+
ConversionTarget target(getContext());
88+
target.addLegalDialect<func::FuncDialect>();
89+
target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
90+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
91+
signalPassFailure();
92+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt %s -convert-complex-to-rocdl-library-calls | FileCheck %s
2+
3+
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
4+
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
5+
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
6+
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
7+
8+
//CHECK-LABEL: @abs_caller
9+
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
10+
// CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
11+
%rf = complex.abs %f : complex<f32>
12+
// CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
13+
%rd = complex.abs %d : complex<f64>
14+
// CHECK: return %[[RF]], %[[RD]]
15+
return %rf, %rd : f32, f64
16+
}
17+
18+
//CHECK-LABEL: @exp_caller
19+
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
20+
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
21+
%ef = complex.exp %f : complex<f32>
22+
// CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
23+
%ed = complex.exp %d : complex<f64>
24+
// CHECK: return %[[EF]], %[[ED]]
25+
return %ef, %ed : complex<f32>, complex<f64>
26+
}

0 commit comments

Comments
 (0)