1313#include " mlir/Dialect/Func/IR/FuncOps.h"
1414#include " mlir/Dialect/Func/Utils/Utils.h"
1515#include " mlir/IR/Verifier.h"
16- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+ #include " llvm/ADT/TypeSwitch.h"
1718
1819namespace mlir {
1920#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -23,100 +24,55 @@ namespace mlir {
2324using namespace mlir ;
2425using namespace mlir ::func;
2526
26- #define APFLOAT_BIN_OPS (X ) \
27- X (add) \
28- X(subtract) \
29- X(multiply) \
30- X(divide) \
31- X(remainder) \
32- X(mod)
33-
34- #define APFLOAT_EXTERN_K (OP ) kApFloat_ ##OP
35-
36- #define APFLOAT_EXTERN_NAME (OP ) \
37- static constexpr llvm::StringRef APFLOAT_EXTERN_K (OP) = "_mlir_" \
38- "apfloat_" #OP;
39-
40- namespace mlir ::func {
41- #define LOOKUP_OR_CREATE_APFLOAT_FN_DECL (OP ) \
42- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
43- OpBuilder &b, Operation *moduleOp, \
44- SymbolTableCollection *symbolTables = nullptr );
45-
46- APFLOAT_BIN_OPS (LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
47-
48- #undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
49-
50- APFLOAT_BIN_OPS (APFLOAT_EXTERN_NAME)
51-
52- #define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN (OP ) \
53- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
54- OpBuilder &b, Operation *moduleOp, \
55- SymbolTableCollection *symbolTables) { \
56- return lookupOrCreateFn (b, moduleOp, APFLOAT_EXTERN_K (OP), \
57- {IntegerType::get (moduleOp->getContext (), 32 ), \
58- IntegerType::get (moduleOp->getContext (), 64 ), \
59- IntegerType::get (moduleOp->getContext (), 64 )}, \
60- {IntegerType::get (moduleOp->getContext (), 64 )}, \
61- /* setPrivate*/ true , symbolTables); \
62- }
63-
64- APFLOAT_BIN_OPS (LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
65- #undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
66- } // namespace mlir::func
67-
68- struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
69- using OpRewritePattern::OpRewritePattern;
70-
71- LogicalResult matchAndRewrite (arith::AddFOp op,
72- PatternRewriter &rewriter) const override {
73- // Get APFloat adder function from runtime library.
74- auto parent = op->getParentOfType <ModuleOp>();
75- if (!parent)
76- return failure ();
77- if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
78- Float8E5M2FNUZType, Float8E4M3FNUZType,
79- Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
80- Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
81- op.getType ()))
82- return failure ();
83- FailureOr<Operation *> adder = lookupOrCreateApFloataddFn (rewriter, parent);
84-
85- // Cast operands to 64-bit integers.
86- Location loc = op.getLoc ();
87- auto floatTy = cast<FloatType>(op.getType ());
88- auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
89- auto int64Type = rewriter.getI64Type ();
90- Value lhsBits = arith::ExtUIOp::create (
91- rewriter, loc, int64Type,
92- arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
93- Value rhsBits = arith::ExtUIOp::create (
94- rewriter, loc, int64Type,
95- arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
96-
97- // Call software implementation of floating point addition.
98- int32_t sem =
99- llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
100- Value semValue = arith::ConstantOp::create (
101- rewriter, loc, rewriter.getI32Type (),
102- rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
103- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
104- auto resultOp =
105- func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
106- SymbolRefAttr::get (*adder), params);
107-
108- // Truncate result to the original width.
109- Value truncatedBits = arith::TruncIOp::create (rewriter, loc, intWType,
110- resultOp->getResult (0 ));
111- rewriter.replaceAllUsesWith (
112- op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
113- return success ();
114- }
115- };
27+ static FailureOr<Operation *>
28+ lookupOrCreateBinaryFn (OpBuilder &b, Operation *moduleOp, StringRef name,
29+ SymbolTableCollection *symbolTables = nullptr ) {
30+ return lookupOrCreateFn (b, moduleOp,
31+ (llvm::Twine (" _mlir_apfloat_" ) + name).str (),
32+ {IntegerType::get (moduleOp->getContext (), 32 ),
33+ IntegerType::get (moduleOp->getContext (), 64 ),
34+ IntegerType::get (moduleOp->getContext (), 64 )},
35+ {IntegerType::get (moduleOp->getContext (), 64 )},
36+ /* setPrivate*/ true , symbolTables);
37+ }
11638
117- void arith::populateArithToAPFloatConversionPatterns (
118- RewritePatternSet &patterns) {
119- patterns.add <FancyAddFLowering>(patterns.getContext ());
39+ template <typename OpTy>
40+ static LogicalResult rewriteBinaryOp (RewriterBase &rewriter, ModuleOp module ,
41+ OpTy op, StringRef apfloatName) {
42+ // Get APFloat function from runtime library.
43+ FailureOr<Operation *> fn =
44+ lookupOrCreateBinaryFn (rewriter, module , apfloatName);
45+ if (failed (fn))
46+ return op->emitError (" failed to lookup or create APFloat function" );
47+
48+ // Cast operands to 64-bit integers.
49+ Location loc = op.getLoc ();
50+ auto floatTy = cast<FloatType>(op.getType ());
51+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
52+ auto int64Type = rewriter.getI64Type ();
53+ Value lhsBits = arith::ExtUIOp::create (
54+ rewriter, loc, int64Type,
55+ arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
56+ Value rhsBits = arith::ExtUIOp::create (
57+ rewriter, loc, int64Type,
58+ arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
59+
60+ // Call APFloat function.
61+ int32_t sem = llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
62+ Value semValue = arith::ConstantOp::create (
63+ rewriter, loc, rewriter.getI32Type (),
64+ rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
65+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
66+ auto resultOp =
67+ func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
68+ SymbolRefAttr::get (*fn), params);
69+
70+ // Truncate result to the original width.
71+ Value truncatedBits =
72+ arith::TruncIOp::create (rewriter, loc, intWType, resultOp->getResult (0 ));
73+ rewriter.replaceOp (
74+ op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
75+ return success ();
12076}
12177
12278namespace {
@@ -126,10 +82,31 @@ struct ArithToAPFloatConversionPass final
12682 ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
12783
12884 void runOnOperation () override {
129- Operation *op = getOperation ();
130- RewritePatternSet patterns (op->getContext ());
131- arith::populateArithToAPFloatConversionPatterns (patterns);
132- if (failed (applyPatternsGreedily (op, std::move (patterns))))
85+ ModuleOp module = getOperation ();
86+ IRRewriter rewriter (getOperation ()->getContext ());
87+ SmallVector<arith::AddFOp> addOps;
88+ WalkResult status = module ->walk ([&](Operation *op) {
89+ rewriter.setInsertionPoint (op);
90+ LogicalResult result =
91+ llvm::TypeSwitch<Operation *, LogicalResult>(op)
92+ .Case <arith::AddFOp>([&](arith::AddFOp op) {
93+ return rewriteBinaryOp (rewriter, module , op, " add" );
94+ })
95+ .Case <arith::SubFOp>([&](arith::SubFOp op) {
96+ return rewriteBinaryOp (rewriter, module , op, " subtract" );
97+ })
98+ .Case <arith::MulFOp>([&](arith::MulFOp op) {
99+ return rewriteBinaryOp (rewriter, module , op, " mulitply" );
100+ })
101+ .Case <arith::DivFOp>([&](arith::DivFOp op) {
102+ return rewriteBinaryOp (rewriter, module , op, " divide" );
103+ })
104+ .Default ([](Operation *op) { return success (); });
105+ if (failed (result))
106+ return WalkResult::interrupt ();
107+ return WalkResult::advance ();
108+ });
109+ if (status.wasInterrupted ())
133110 return signalPassFailure ();
134111 }
135112};
0 commit comments