1111#include " mlir/Dialect/Func/IR/FuncOps.h"
1212#include " mlir/IR/PatternMatch.h"
1313#include " mlir/Transforms/DialectConversion.h"
14- #include < optional>
1514
1615namespace mlir {
1716#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
@@ -21,36 +20,38 @@ namespace mlir {
2120using namespace mlir ;
2221
2322namespace {
24- struct FloatTypeResolver {
25- std::optional<bool > operator ()(Type type) const {
26- auto elementType = cast<FloatType>(type);
27- if (!isa<Float32Type, Float64Type>(elementType))
28- return {};
29- return elementType.getIntOrFloatBitWidth () == 64 ;
30- }
31- };
3223
33- template <typename Op, typename TypeResolver = FloatTypeResolver>
34- struct ScalarOpToROCDLCall : public OpRewritePattern <Op> {
24+ template <typename Op>
25+ // Pattern to convert Complex ops to ROCDL function calls.
26+ struct ComplexOpToROCDLCall : public OpRewritePattern <Op> {
3527 using OpRewritePattern<Op>::OpRewritePattern;
36- ScalarOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
37- StringRef doubleFunc, PatternBenefit benefit)
28+ ComplexOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
29+ StringRef doubleFunc, PatternBenefit benefit = 1 )
3830 : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
3931 doubleFunc (doubleFunc) {}
4032
4133 LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final {
42- auto module = SymbolTable::getNearestSymbolTable (op);
43- auto isDouble = TypeResolver ()(op.getType ());
44- if (!isDouble.has_value ())
34+ Operation *symTable = SymbolTable::getNearestSymbolTable (op);
35+ Type resType = op.getType ();
36+ if (auto complexType = dyn_cast<ComplexType>(resType))
37+ resType = complexType.getElementType ();
38+ FloatType floatTy = dyn_cast<FloatType>(resType);
39+ if (!floatTy)
4540 return failure ();
4641
47- auto name = *isDouble ? doubleFunc : floatFunc;
42+ StringRef name;
43+ if (floatTy.isF64 ())
44+ name = doubleFunc;
45+ else if (floatTy.isF32 ())
46+ name = floatFunc;
47+ else
48+ return failure ();
4849
4950 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
50- SymbolTable::lookupSymbolIn (module , name));
51+ SymbolTable::lookupSymbolIn (symTable , name));
5152 if (!opFunc) {
5253 OpBuilder::InsertionGuard guard (rewriter);
53- rewriter.setInsertionPointToStart (&module ->getRegion (0 ).front ());
54+ rewriter.setInsertionPointToStart (&symTable ->getRegion (0 ).front ());
5455 auto funcTy = FunctionType::get (
5556 rewriter.getContext (), op->getOperandTypes (), op->getResultTypes ());
5657 opFunc =
@@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
6768};
6869} // namespace
6970
70- void mlir::populateComplexToROCDLConversionPatterns (RewritePatternSet &patterns,
71- PatternBenefit benefit) {
72- patterns.add <ScalarOpToROCDLCall<complex ::AbsOp>>(
73- patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" , benefit);
71+ void mlir::populateComplexToROCDLConversionPatterns (
72+ RewritePatternSet &patterns) {
73+ patterns.add <ComplexOpToROCDLCall<complex ::AbsOp>>(
74+ patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" );
75+ patterns.add <ComplexOpToROCDLCall<complex ::ExpOp>>(
76+ patterns.getContext (), " __ocml_cexp_f32" , " __ocml_cexp_f64" );
7477}
7578
7679namespace {
@@ -81,14 +84,14 @@ struct ConvertComplexToROCDLPass
8184} // namespace
8285
8386void ConvertComplexToROCDLPass::runOnOperation () {
84- auto module = getOperation ();
87+ Operation *op = getOperation ();
8588
8689 RewritePatternSet patterns (&getContext ());
87- populateComplexToROCDLConversionPatterns (patterns, /* benefit= */ 1 );
90+ populateComplexToROCDLConversionPatterns (patterns);
8891
8992 ConversionTarget target (getContext ());
9093 target.addLegalDialect <func::FuncDialect>();
91- target.addIllegalOp <complex ::AbsOp>();
92- if (failed (applyPartialConversion (module , target, std::move (patterns))))
94+ target.addIllegalOp <complex ::AbsOp, complex ::ExpOp >();
95+ if (failed (applyPartialConversion (op , target, std::move (patterns))))
9396 signalPassFailure ();
9497}
0 commit comments