1- // ===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
1+ // ===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion ----- ---------===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
1313#include " mlir/Dialect/Func/IR/FuncOps.h"
1414#include " mlir/Dialect/Func/Utils/Utils.h"
1515#include " mlir/IR/Verifier.h"
16- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+ #include " llvm/ADT/TypeSwitch.h"
1718
1819namespace mlir {
1920#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -23,100 +24,66 @@ namespace mlir {
2324using namespace mlir ;
2425using namespace mlir ::func;
2526
26- #define APFLOAT_BIN_OPS (X ) \
27- X (add) \
28- X(subtract) \
29- X(multiply) \
30- X(divide) \
31- X(remainder) \
32- X(mod)
33-
34- #define APFLOAT_EXTERN_K (OP ) kApFloat_ ##OP
35-
36- #define APFLOAT_EXTERN_NAME (OP ) \
37- static constexpr llvm::StringRef APFLOAT_EXTERN_K (OP) = "_mlir_" \
38- "apfloat_" #OP;
39-
40- namespace mlir ::func {
41- #define LOOKUP_OR_CREATE_APFLOAT_FN_DECL (OP ) \
42- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
43- OpBuilder &b, Operation *moduleOp, \
44- SymbolTableCollection *symbolTables = nullptr );
45-
46- APFLOAT_BIN_OPS (LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
47-
48- #undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
49-
50- APFLOAT_BIN_OPS (APFLOAT_EXTERN_NAME)
51-
52- #define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN (OP ) \
53- FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
54- OpBuilder &b, Operation *moduleOp, \
55- SymbolTableCollection *symbolTables) { \
56- return lookupOrCreateFn (b, moduleOp, APFLOAT_EXTERN_K (OP), \
57- {IntegerType::get (moduleOp->getContext (), 32 ), \
58- IntegerType::get (moduleOp->getContext (), 64 ), \
59- IntegerType::get (moduleOp->getContext (), 64 )}, \
60- {IntegerType::get (moduleOp->getContext (), 64 )}, \
61- /* setPrivate*/ true , symbolTables); \
62- }
63-
64- APFLOAT_BIN_OPS (LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
65- #undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
66- } // namespace mlir::func
67-
68- struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
69- using OpRewritePattern::OpRewritePattern;
70-
71- LogicalResult matchAndRewrite (arith::AddFOp op,
72- PatternRewriter &rewriter) const override {
73- // Get APFloat adder function from runtime library.
74- auto parent = op->getParentOfType <ModuleOp>();
75- if (!parent)
76- return failure ();
77- if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
78- Float8E5M2FNUZType, Float8E4M3FNUZType,
79- Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
80- Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
81- op.getType ()))
82- return failure ();
83- FailureOr<Operation *> adder = lookupOrCreateApFloataddFn (rewriter, parent);
84-
85- // Cast operands to 64-bit integers.
86- Location loc = op.getLoc ();
87- auto floatTy = cast<FloatType>(op.getType ());
88- auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
89- auto int64Type = rewriter.getI64Type ();
90- Value lhsBits = arith::ExtUIOp::create (
91- rewriter, loc, int64Type,
92- arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
93- Value rhsBits = arith::ExtUIOp::create (
94- rewriter, loc, int64Type,
95- arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
96-
97- // Call software implementation of floating point addition.
98- int32_t sem =
99- llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
100- Value semValue = arith::ConstantOp::create (
101- rewriter, loc, rewriter.getI32Type (),
102- rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
103- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
104- auto resultOp =
105- func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
106- SymbolRefAttr::get (*adder), params);
107-
108- // Truncate result to the original width.
109- Value truncatedBits = arith::TruncIOp::create (rewriter, loc, intWType,
110- resultOp->getResult (0 ));
111- rewriter.replaceAllUsesWith (
112- op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
113- return success ();
114- }
115- };
27+ // / Helper function to lookup or create the symbol for a runtime library
28+ // / function for a binary arithmetic operation.
29+ // /
30+ // / Parameter 1: APFloat semantics
31+ // / Parameter 2: Left-hand side operand
32+ // / Parameter 3: Right-hand side operand
33+ // /
34+ // / This function will return a failure if the function is found but has an
35+ // / unexpected signature.
36+ // /
37+ static FailureOr<Operation *>
38+ lookupOrCreateBinaryFn (OpBuilder &b, Operation *moduleOp, StringRef name,
39+ SymbolTableCollection *symbolTables = nullptr ) {
40+ return lookupOrCreateFn (b, moduleOp,
41+ (llvm::Twine (" _mlir_apfloat_" ) + name).str (),
42+ {IntegerType::get (moduleOp->getContext (), 32 ),
43+ IntegerType::get (moduleOp->getContext (), 64 ),
44+ IntegerType::get (moduleOp->getContext (), 64 )},
45+ {IntegerType::get (moduleOp->getContext (), 64 )},
46+ /* setPrivate=*/ true , symbolTables);
47+ }
11648
117- void arith::populateArithToAPFloatConversionPatterns (
118- RewritePatternSet &patterns) {
119- patterns.add <FancyAddFLowering>(patterns.getContext ());
49+ // / Rewrite a binary arithmetic operation to an APFloat function call.
50+ template <typename OpTy>
51+ static LogicalResult rewriteBinaryOp (RewriterBase &rewriter, ModuleOp module ,
52+ OpTy op, StringRef apfloatName) {
53+ // Get APFloat function from runtime library.
54+ FailureOr<Operation *> fn =
55+ lookupOrCreateBinaryFn (rewriter, module , apfloatName);
56+ if (failed (fn))
57+ return op->emitError (" failed to lookup or create APFloat function" );
58+
59+ // Cast operands to 64-bit integers.
60+ Location loc = op.getLoc ();
61+ auto floatTy = cast<FloatType>(op.getType ());
62+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
63+ auto int64Type = rewriter.getI64Type ();
64+ Value lhsBits = arith::ExtUIOp::create (
65+ rewriter, loc, int64Type,
66+ arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
67+ Value rhsBits = arith::ExtUIOp::create (
68+ rewriter, loc, int64Type,
69+ arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
70+
71+ // Call APFloat function.
72+ int32_t sem = llvm::APFloatBase::SemanticsToEnum (floatTy.getFloatSemantics ());
73+ Value semValue = arith::ConstantOp::create (
74+ rewriter, loc, rewriter.getI32Type (),
75+ rewriter.getIntegerAttr (rewriter.getI32Type (), sem));
76+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
77+ auto resultOp =
78+ func::CallOp::create (rewriter, loc, TypeRange (rewriter.getI64Type ()),
79+ SymbolRefAttr::get (*fn), params);
80+
81+ // Truncate result to the original width.
82+ Value truncatedBits =
83+ arith::TruncIOp::create (rewriter, loc, intWType, resultOp->getResult (0 ));
84+ rewriter.replaceOp (
85+ op, arith::BitcastOp::create (rewriter, loc, floatTy, truncatedBits));
86+ return success ();
12087}
12188
12289namespace {
@@ -126,10 +93,34 @@ struct ArithToAPFloatConversionPass final
12693 ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
12794
12895 void runOnOperation () override {
129- Operation *op = getOperation ();
130- RewritePatternSet patterns (op->getContext ());
131- arith::populateArithToAPFloatConversionPatterns (patterns);
132- if (failed (applyPatternsGreedily (op, std::move (patterns))))
96+ ModuleOp module = getOperation ();
97+ IRRewriter rewriter (getOperation ()->getContext ());
98+ SmallVector<arith::AddFOp> addOps;
99+ WalkResult status = module ->walk ([&](Operation *op) {
100+ rewriter.setInsertionPoint (op);
101+ LogicalResult result =
102+ llvm::TypeSwitch<Operation *, LogicalResult>(op)
103+ .Case <arith::AddFOp>([&](arith::AddFOp op) {
104+ return rewriteBinaryOp (rewriter, module , op, " add" );
105+ })
106+ .Case <arith::SubFOp>([&](arith::SubFOp op) {
107+ return rewriteBinaryOp (rewriter, module , op, " subtract" );
108+ })
109+ .Case <arith::MulFOp>([&](arith::MulFOp op) {
110+ return rewriteBinaryOp (rewriter, module , op, " multiply" );
111+ })
112+ .Case <arith::DivFOp>([&](arith::DivFOp op) {
113+ return rewriteBinaryOp (rewriter, module , op, " divide" );
114+ })
115+ .Case <arith::RemFOp>([&](arith::RemFOp op) {
116+ return rewriteBinaryOp (rewriter, module , op, " remainder" );
117+ })
118+ .Default ([](Operation *op) { return success (); });
119+ if (failed (result))
120+ return WalkResult::interrupt ();
121+ return WalkResult::advance ();
122+ });
123+ if (status.wasInterrupted ())
133124 return signalPassFailure ();
134125 }
135126};
0 commit comments