Skip to content

Commit d94e2d1

Browse files
committed
address the rest of the comments
1 parent b644bfb commit d94e2d1

File tree

1 file changed

+62
-67
lines changed

1 file changed

+62
-67
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 62 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/Func/Utils/Utils.h"
15+
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/IR/Verifier.h"
16-
17-
#include "llvm/ADT/TypeSwitch.h"
17+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1818

1919
namespace mlir {
2020
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -24,7 +24,7 @@ namespace mlir {
2424
using namespace mlir;
2525
using namespace mlir::func;
2626

27-
/// Helper function to lookup or create the symbol for a runtime library
27+
/// Helper function to look up or create the symbol for a runtime library
2828
/// function for a binary arithmetic operation.
2929
///
3030
/// Parameter 1: APFloat semantics
@@ -46,80 +46,75 @@ lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
4646
}
4747

4848
/// Rewrite a binary arithmetic operation to an APFloat function call.
49-
template <typename OpTy>
50-
static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
51-
OpTy op, StringRef apfloatName) {
52-
// Get APFloat function from runtime library.
53-
FailureOr<Operation *> fn =
54-
lookupOrCreateBinaryFn(rewriter, module, apfloatName);
55-
if (failed(fn))
56-
return op->emitError("failed to lookup or create APFloat function");
49+
template <typename OpTy, const char *APFloatName>
50+
struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
51+
using OpRewritePattern<OpTy>::OpRewritePattern;
5752

58-
// Cast operands to 64-bit integers.
59-
Location loc = op.getLoc();
60-
auto floatTy = cast<FloatType>(op.getType());
61-
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
62-
auto int64Type = rewriter.getI64Type();
63-
Value lhsBits = arith::ExtUIOp::create(
64-
rewriter, loc, int64Type,
65-
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
66-
Value rhsBits = arith::ExtUIOp::create(
67-
rewriter, loc, int64Type,
68-
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
53+
LogicalResult matchAndRewrite(OpTy op,
54+
PatternRewriter &rewriter) const override {
55+
auto moduleOp = op->template getParentOfType<ModuleOp>();
56+
if (!moduleOp) {
57+
op.emitError("arith op must be contained within a builtin.module");
58+
return failure();
59+
}
60+
// Get APFloat function from runtime library.
61+
FailureOr<Operation *> fn =
62+
lookupOrCreateBinaryFn(rewriter, moduleOp, APFloatName);
63+
if (failed(fn))
64+
return op->emitError("failed to lookup or create APFloat function");
6965

70-
// Call APFloat function.
71-
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
72-
Value semValue = arith::ConstantOp::create(
73-
rewriter, loc, rewriter.getI32Type(),
74-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
75-
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
76-
auto resultOp =
77-
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
78-
SymbolRefAttr::get(*fn), params);
66+
rewriter.setInsertionPoint(op);
67+
// Cast operands to 64-bit integers.
68+
Location loc = op.getLoc();
69+
auto floatTy = cast<FloatType>(op.getType());
70+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
71+
auto int64Type = rewriter.getI64Type();
72+
Value lhsBits = arith::ExtUIOp::create(
73+
rewriter, loc, int64Type,
74+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
75+
Value rhsBits = arith::ExtUIOp::create(
76+
rewriter, loc, int64Type,
77+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
7978

80-
// Truncate result to the original width.
81-
Value truncatedBits =
82-
arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0));
83-
rewriter.replaceOp(
84-
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
85-
return success();
86-
}
79+
// Call APFloat function.
80+
int32_t sem =
81+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
82+
Value semValue = arith::ConstantOp::create(
83+
rewriter, loc, rewriter.getI32Type(),
84+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
85+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
86+
auto resultOp =
87+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
88+
SymbolRefAttr::get(*fn), params);
89+
90+
// Truncate result to the original width.
91+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
92+
resultOp->getResult(0));
93+
rewriter.replaceOp(
94+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
95+
return success();
96+
}
97+
};
8798

8899
namespace {
89100
struct ArithToAPFloatConversionPass final
90101
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
91102
using Base::Base;
92103

93104
void runOnOperation() override {
94-
ModuleOp moduleOp = getOperation();
95-
IRRewriter rewriter(getOperation()->getContext());
96-
SmallVector<arith::AddFOp> addOps;
97-
WalkResult status = moduleOp->walk([&](Operation *op) {
98-
rewriter.setInsertionPoint(op);
99-
LogicalResult result =
100-
llvm::TypeSwitch<Operation *, LogicalResult>(op)
101-
.Case<arith::AddFOp>([&](arith::AddFOp op) {
102-
return rewriteBinaryOp(rewriter, moduleOp, op, "add");
103-
})
104-
.Case<arith::SubFOp>([&](arith::SubFOp op) {
105-
return rewriteBinaryOp(rewriter, moduleOp, op, "subtract");
106-
})
107-
.Case<arith::MulFOp>([&](arith::MulFOp op) {
108-
return rewriteBinaryOp(rewriter, moduleOp, op, "multiply");
109-
})
110-
.Case<arith::DivFOp>([&](arith::DivFOp op) {
111-
return rewriteBinaryOp(rewriter, moduleOp, op, "divide");
112-
})
113-
.Case<arith::RemFOp>([&](arith::RemFOp op) {
114-
return rewriteBinaryOp(rewriter, moduleOp, op, "remainder");
115-
})
116-
.Default([](Operation *op) { return success(); });
117-
if (failed(result))
118-
return WalkResult::interrupt();
119-
return WalkResult::advance();
120-
});
121-
if (status.wasInterrupted())
122-
return signalPassFailure();
105+
MLIRContext *context = &getContext();
106+
RewritePatternSet patterns(context);
107+
static const char add[] = "add";
108+
static const char subtract[] = "subtract";
109+
static const char multiply[] = "multiply";
110+
static const char divide[] = "divide";
111+
static const char remainder[] = "remainder";
112+
patterns.add<ArithOpToAPFloatConversion<arith::AddFOp, add>,
113+
ArithOpToAPFloatConversion<arith::SubFOp, subtract>,
114+
ArithOpToAPFloatConversion<arith::MulFOp, multiply>,
115+
ArithOpToAPFloatConversion<arith::DivFOp, divide>,
116+
ArithOpToAPFloatConversion<arith::RemFOp, remainder>>(context);
117+
walkAndApplyPatterns(getOperation(), std::move(patterns));
123118
}
124119
};
125120
} // namespace

0 commit comments

Comments
 (0)