1212#include " mlir/Dialect/Arith/Transforms/Passes.h"
1313#include " mlir/Dialect/Func/IR/FuncOps.h"
1414#include " mlir/Dialect/Func/Utils/Utils.h"
15+ #include " mlir/IR/PatternMatch.h"
1516#include " mlir/IR/Verifier.h"
16-
17- #include " llvm/ADT/TypeSwitch.h"
17+ #include " mlir/Transforms/WalkPatternRewriteDriver.h"
1818
1919namespace mlir {
2020#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -24,7 +24,7 @@ namespace mlir {
2424using namespace mlir ;
2525using namespace mlir ::func;
2626
27- // / Helper function to lookup or create the symbol for a runtime library
27+ // / Helper function to look up or create the symbol for a runtime library
2828// / function for a binary arithmetic operation.
2929// /
3030// / Parameter 1: APFloat semantics
@@ -46,80 +46,75 @@ lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
4646}
4747
4848// / Rewrite a binary arithmetic operation to an APFloat function call.
49- template <typename OpTy>
50- static LogicalResult rewriteBinaryOp (RewriterBase &rewriter, ModuleOp module ,
51- OpTy op, StringRef apfloatName) {
52- // Get APFloat function from runtime library.
53- FailureOr<Operation *> fn =
54- lookupOrCreateBinaryFn (rewriter, module , apfloatName);
55- if (failed (fn))
56- return op->emitError (" failed to lookup or create APFloat function" );
49+ template <typename OpTy, const char *APFloatName>
50+ struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
51+ using OpRewritePattern<OpTy>::OpRewritePattern;
5752
58- // Cast operands to 64-bit integers.
59- Location loc = op.getLoc ();
60- auto floatTy = cast<FloatType>(op.getType ());
61- auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
62- auto int64Type = rewriter.getI64Type ();
63- Value lhsBits = arith::ExtUIOp::create (
64- rewriter, loc, int64Type,
65- arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
66- Value rhsBits = arith::ExtUIOp::create (
67- rewriter, loc, int64Type,
68- arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
53+ LogicalResult matchAndRewrite (OpTy op,
54+ PatternRewriter &rewriter) const override {
55+ auto moduleOp = op->template getParentOfType <ModuleOp>();
56+ if (!moduleOp) {
57+ op.emitError (" arith op must be contained within a builtin.module" );
58+ return failure ();
59+ }
60+ // Get APFloat function from runtime library.
61+ FailureOr<Operation *> fn =
62+ lookupOrCreateBinaryFn (rewriter, moduleOp, APFloatName);
63+ if (failed (fn))
64+ return op->emitError (" failed to lookup or create APFloat function" );
6965
70- // Call APFloat function.
71- int32_t sem = llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
72- Value semValue = arith::ConstantOp::create (
73- rewriter, loc, rewriter.getI32Type (),
74- rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
75- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
76- auto resultOp =
77- func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
78- SymbolRefAttr::get (*fn), params);
66+ rewriter.setInsertionPoint (op);
67+ // Cast operands to 64-bit integers.
68+ Location loc = op.getLoc ();
69+ auto floatTy = cast<FloatType>(op.getType ());
70+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
71+ auto int64Type = rewriter.getI64Type ();
72+ Value lhsBits = arith::ExtUIOp::create (
73+ rewriter, loc, int64Type,
74+ arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
75+ Value rhsBits = arith::ExtUIOp::create (
76+ rewriter, loc, int64Type,
77+ arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
7978
80- // Truncate result to the original width.
81- Value truncatedBits =
82- arith::TruncIOp::create (rewriter, loc, intWType, resultOp->getResult (0 ));
83- rewriter.replaceOp (
84- op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
85- return success ();
86- }
79+ // Call APFloat function.
80+ int32_t sem =
81+ llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
82+ Value semValue = arith::ConstantOp::create (
83+ rewriter, loc, rewriter.getI32Type (),
84+ rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
85+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
86+ auto resultOp =
87+ func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
88+ SymbolRefAttr::get (*fn), params);
89+
90+ // Truncate result to the original width.
91+ Value truncatedBits = arith::TruncIOp::create (rewriter, loc, intWType,
92+ resultOp->getResult (0 ));
93+ rewriter.replaceOp (
94+ op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
95+ return success ();
96+ }
97+ };
8798
8899namespace {
89100struct ArithToAPFloatConversionPass final
90101 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
91102 using Base::Base;
92103
93104 void runOnOperation () override {
94- ModuleOp moduleOp = getOperation ();
95- IRRewriter rewriter (getOperation ()->getContext ());
96- SmallVector<arith::AddFOp> addOps;
97- WalkResult status = moduleOp->walk ([&](Operation *op) {
98- rewriter.setInsertionPoint (op);
99- LogicalResult result =
100- llvm::TypeSwitch<Operation *, LogicalResult>(op)
101- .Case <arith::AddFOp>([&](arith::AddFOp op) {
102- return rewriteBinaryOp (rewriter, moduleOp, op, " add" );
103- })
104- .Case <arith::SubFOp>([&](arith::SubFOp op) {
105- return rewriteBinaryOp (rewriter, moduleOp, op, " subtract" );
106- })
107- .Case <arith::MulFOp>([&](arith::MulFOp op) {
108- return rewriteBinaryOp (rewriter, moduleOp, op, " multiply" );
109- })
110- .Case <arith::DivFOp>([&](arith::DivFOp op) {
111- return rewriteBinaryOp (rewriter, moduleOp, op, " divide" );
112- })
113- .Case <arith::RemFOp>([&](arith::RemFOp op) {
114- return rewriteBinaryOp (rewriter, moduleOp, op, " remainder" );
115- })
116- .Default ([](Operation *op) { return success (); });
117- if (failed (result))
118- return WalkResult::interrupt ();
119- return WalkResult::advance ();
120- });
121- if (status.wasInterrupted ())
122- return signalPassFailure ();
105+ MLIRContext *context = &getContext ();
106+ RewritePatternSet patterns (context);
107+ static const char add[] = " add" ;
108+ static const char subtract[] = " subtract" ;
109+ static const char multiply[] = " multiply" ;
110+ static const char divide[] = " divide" ;
111+ static const char remainder[] = " remainder" ;
112+ patterns.add <ArithOpToAPFloatConversion<arith::AddFOp, add>,
113+ ArithOpToAPFloatConversion<arith::SubFOp, subtract>,
114+ ArithOpToAPFloatConversion<arith::MulFOp, multiply>,
115+ ArithOpToAPFloatConversion<arith::DivFOp, divide>,
116+ ArithOpToAPFloatConversion<arith::RemFOp, remainder>>(context);
117+ walkAndApplyPatterns (getOperation (), std::move (patterns));
123118 }
124119};
125120} // namespace
0 commit comments