@@ -21,59 +21,53 @@ using namespace mlir;
2121
2222namespace {
2323
24- template <typename Op>
24+ template <typename Op, typename Ty >
2525// Pattern to convert Complex ops to ROCDL function calls.
2626struct ComplexOpToROCDLCall : public OpRewritePattern <Op> {
2727 using OpRewritePattern<Op>::OpRewritePattern;
28- ComplexOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
29- StringRef doubleFunc, PatternBenefit benefit = 1 )
30- : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
31- doubleFunc (doubleFunc) {}
28+ ComplexOpToROCDLCall (MLIRContext *context, StringRef funcName,
29+ PatternBenefit benefit = 1 )
30+ : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
3231
3332 LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final {
3433 Operation *symTable = SymbolTable::getNearestSymbolTable (op);
3534 Type resType = op.getType ();
3635 if (auto complexType = dyn_cast<ComplexType>(resType))
3736 resType = complexType.getElementType ();
38- FloatType floatTy = dyn_cast<FloatType>(resType);
39- if (!floatTy)
40- return failure ();
41-
42- StringRef name;
43- if (floatTy.isF64 ())
44- name = doubleFunc;
45- else if (floatTy.isF32 ())
46- name = floatFunc;
47- else
37+ if (!isa<Ty>(resType))
4838 return failure ();
4939
5040 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
51- SymbolTable::lookupSymbolIn (symTable, name ));
41+ SymbolTable::lookupSymbolIn (symTable, funcName ));
5242 if (!opFunc) {
5343 OpBuilder::InsertionGuard guard (rewriter);
5444 rewriter.setInsertionPointToStart (&symTable->getRegion (0 ).front ());
5545 auto funcTy = FunctionType::get (
5646 rewriter.getContext (), op->getOperandTypes (), op->getResultTypes ());
57- opFunc =
58- rewriter. create <func::FuncOp>(rewriter. getUnknownLoc (), name, funcTy);
47+ opFunc = rewriter. create <func::FuncOp>(rewriter. getUnknownLoc (), funcName,
48+ funcTy);
5949 opFunc.setPrivate ();
6050 }
61- rewriter.replaceOpWithNewOp <func::CallOp>(op, name , op.getType (),
51+ rewriter.replaceOpWithNewOp <func::CallOp>(op, funcName , op.getType (),
6252 op->getOperands ());
6353 return success ();
6454 }
6555
6656private:
67- std::string floatFunc, doubleFunc ;
57+ std::string funcName ;
6858};
6959} // namespace
7060
7161void mlir::populateComplexToROCDLConversionPatterns (
7262 RewritePatternSet &patterns) {
73- patterns.add <ComplexOpToROCDLCall<complex ::AbsOp>>(
74- patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" );
75- patterns.add <ComplexOpToROCDLCall<complex ::ExpOp>>(
76- patterns.getContext (), " __ocml_cexp_f32" , " __ocml_cexp_f64" );
63+ patterns.add <ComplexOpToROCDLCall<complex ::AbsOp, Float32Type>>(
64+ patterns.getContext (), " __ocml_cabs_f32" );
65+ patterns.add <ComplexOpToROCDLCall<complex ::AbsOp, Float64Type>>(
66+ patterns.getContext (), " __ocml_cabs_f64" );
67+ patterns.add <ComplexOpToROCDLCall<complex ::ExpOp, Float32Type>>(
68+ patterns.getContext (), " __ocml_cexp_f32" );
69+ patterns.add <ComplexOpToROCDLCall<complex ::ExpOp, Float64Type>>(
70+ patterns.getContext (), " __ocml_cexp_f64" );
7771}
7872
7973namespace {
0 commit comments