Skip to content

Commit 1180064

Browse files
walk instead of dialect conversion
1 parent 78df4a8 commit 1180064

File tree

6 files changed

+219
-155
lines changed

6 files changed

+219
-155
lines changed

mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,10 @@
1212
#include <memory>
1313

1414
namespace mlir {
15-
16-
class DialectRegistry;
17-
class RewritePatternSet;
1815
class Pass;
1916

2017
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
2118
#include "mlir/Conversion/Passes.h.inc"
22-
23-
namespace arith {
24-
void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
25-
} // namespace arith
2619
} // namespace mlir
2720

28-
#endif // MLIR_CONVERSION_ARITHTOAPFloat_ARITHTOAPFloat_H
21+
#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,13 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
190190
// ArithToAPFloat
191191
//===----------------------------------------------------------------------===//
192192

193-
def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
194-
let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
193+
def ArithToAPFloatConversionPass
194+
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
195+
let summary = "Convert Arith ops to APFloat runtime library calls";
195196
let description = [{
196-
This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
197+
This pass converts supported Arith ops to APFloat-based runtime library
198+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
199+
floating-point arithmetic operations.
197200
}];
198201
let dependentDialects = ["func::FuncDialect"];
199202
let options = [];

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 90 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- ArithToAPFloat.cpp - Arithmetic to APFloat impl conversion ---------===//
1+
//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -13,7 +13,8 @@
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/Func/Utils/Utils.h"
1515
#include "mlir/IR/Verifier.h"
16-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
#include "llvm/ADT/TypeSwitch.h"
1718

1819
namespace mlir {
1920
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -23,100 +24,66 @@ namespace mlir {
2324
using namespace mlir;
2425
using namespace mlir::func;
2526

26-
#define APFLOAT_BIN_OPS(X) \
27-
X(add) \
28-
X(subtract) \
29-
X(multiply) \
30-
X(divide) \
31-
X(remainder) \
32-
X(mod)
33-
34-
#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
35-
36-
#define APFLOAT_EXTERN_NAME(OP) \
37-
static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \
38-
"apfloat_" #OP;
39-
40-
namespace mlir::func {
41-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
42-
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
43-
OpBuilder &b, Operation *moduleOp, \
44-
SymbolTableCollection *symbolTables = nullptr);
45-
46-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
47-
48-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
49-
50-
APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
51-
52-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
53-
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
54-
OpBuilder &b, Operation *moduleOp, \
55-
SymbolTableCollection *symbolTables) { \
56-
return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \
57-
{IntegerType::get(moduleOp->getContext(), 32), \
58-
IntegerType::get(moduleOp->getContext(), 64), \
59-
IntegerType::get(moduleOp->getContext(), 64)}, \
60-
{IntegerType::get(moduleOp->getContext(), 64)}, \
61-
/*setPrivate*/ true, symbolTables); \
62-
}
63-
64-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
65-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
66-
} // namespace mlir::func
67-
68-
struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
69-
using OpRewritePattern::OpRewritePattern;
70-
71-
LogicalResult matchAndRewrite(arith::AddFOp op,
72-
PatternRewriter &rewriter) const override {
73-
// Get APFloat adder function from runtime library.
74-
auto parent = op->getParentOfType<ModuleOp>();
75-
if (!parent)
76-
return failure();
77-
if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
78-
Float8E5M2FNUZType, Float8E4M3FNUZType,
79-
Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
80-
Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
81-
op.getType()))
82-
return failure();
83-
FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent);
84-
85-
// Cast operands to 64-bit integers.
86-
Location loc = op.getLoc();
87-
auto floatTy = cast<FloatType>(op.getType());
88-
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
89-
auto int64Type = rewriter.getI64Type();
90-
Value lhsBits = arith::ExtUIOp::create(
91-
rewriter, loc, int64Type,
92-
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
93-
Value rhsBits = arith::ExtUIOp::create(
94-
rewriter, loc, int64Type,
95-
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
96-
97-
// Call software implementation of floating point addition.
98-
int32_t sem =
99-
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
100-
Value semValue = arith::ConstantOp::create(
101-
rewriter, loc, rewriter.getI32Type(),
102-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
103-
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
104-
auto resultOp =
105-
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
106-
SymbolRefAttr::get(*adder), params);
107-
108-
// Truncate result to the original width.
109-
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
110-
resultOp->getResult(0));
111-
rewriter.replaceAllUsesWith(
112-
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
113-
return success();
114-
}
115-
};
27+
/// Helper function to lookup or create the symbol for a runtime library
28+
/// function for a binary arithmetic operation.
29+
///
30+
/// Parameter 1: APFloat semantics
31+
/// Parameter 2: Left-hand side operand
32+
/// Parameter 3: Right-hand side operand
33+
///
34+
/// This function will return a failure if the function is found but has an
35+
/// unexpected signature.
36+
///
37+
static FailureOr<Operation *>
38+
lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
39+
SymbolTableCollection *symbolTables = nullptr) {
40+
return lookupOrCreateFn(b, moduleOp,
41+
(llvm::Twine("_mlir_apfloat_") + name).str(),
42+
{IntegerType::get(moduleOp->getContext(), 32),
43+
IntegerType::get(moduleOp->getContext(), 64),
44+
IntegerType::get(moduleOp->getContext(), 64)},
45+
{IntegerType::get(moduleOp->getContext(), 64)},
46+
/*setPrivate=*/true, symbolTables);
47+
}
11648

117-
void arith::populateArithToAPFloatConversionPatterns(
118-
RewritePatternSet &patterns) {
119-
patterns.add<FancyAddFLowering>(patterns.getContext());
49+
/// Rewrite a binary arithmetic operation to an APFloat function call.
50+
template <typename OpTy>
51+
static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
52+
OpTy op, StringRef apfloatName) {
53+
// Get APFloat function from runtime library.
54+
FailureOr<Operation *> fn =
55+
lookupOrCreateBinaryFn(rewriter, module, apfloatName);
56+
if (failed(fn))
57+
return op->emitError("failed to lookup or create APFloat function");
58+
59+
// Cast operands to 64-bit integers.
60+
Location loc = op.getLoc();
61+
auto floatTy = cast<FloatType>(op.getType());
62+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
63+
auto int64Type = rewriter.getI64Type();
64+
Value lhsBits = arith::ExtUIOp::create(
65+
rewriter, loc, int64Type,
66+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
67+
Value rhsBits = arith::ExtUIOp::create(
68+
rewriter, loc, int64Type,
69+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
70+
71+
// Call APFloat function.
72+
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
73+
Value semValue = arith::ConstantOp::create(
74+
rewriter, loc, rewriter.getI32Type(),
75+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
76+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
77+
auto resultOp =
78+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
79+
SymbolRefAttr::get(*fn), params);
80+
81+
// Truncate result to the original width.
82+
Value truncatedBits =
83+
arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0));
84+
rewriter.replaceOp(
85+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
86+
return success();
12087
}
12188

12289
namespace {
@@ -126,10 +93,34 @@ struct ArithToAPFloatConversionPass final
12693
ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
12794

12895
void runOnOperation() override {
129-
Operation *op = getOperation();
130-
RewritePatternSet patterns(op->getContext());
131-
arith::populateArithToAPFloatConversionPatterns(patterns);
132-
if (failed(applyPatternsGreedily(op, std::move(patterns))))
96+
ModuleOp module = getOperation();
97+
IRRewriter rewriter(getOperation()->getContext());
98+
SmallVector<arith::AddFOp> addOps;
99+
WalkResult status = module->walk([&](Operation *op) {
100+
rewriter.setInsertionPoint(op);
101+
LogicalResult result =
102+
llvm::TypeSwitch<Operation *, LogicalResult>(op)
103+
.Case<arith::AddFOp>([&](arith::AddFOp op) {
104+
return rewriteBinaryOp(rewriter, module, op, "add");
105+
})
106+
.Case<arith::SubFOp>([&](arith::SubFOp op) {
107+
return rewriteBinaryOp(rewriter, module, op, "subtract");
108+
})
109+
.Case<arith::MulFOp>([&](arith::MulFOp op) {
110+
return rewriteBinaryOp(rewriter, module, op, "multiply");
111+
})
112+
.Case<arith::DivFOp>([&](arith::DivFOp op) {
113+
return rewriteBinaryOp(rewriter, module, op, "divide");
114+
})
115+
.Case<arith::RemFOp>([&](arith::RemFOp op) {
116+
return rewriteBinaryOp(rewriter, module, op, "remainder");
117+
})
118+
.Default([](Operation *op) { return success(); });
119+
if (failed(result))
120+
return WalkResult::interrupt();
121+
return WalkResult::advance();
122+
});
123+
if (status.wasInterrupted())
133124
return signalPassFailure();
134125
}
135126
};
Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,55 @@
1-
//===- ArmRunnerUtils.cpp - Utilities for configuring architecture properties //
1+
//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetic -----===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8-
8+
//
9+
// This file exposes the APFloat infrastructure to MLIR programs as a runtime
10+
// library. APFloat is a software implementation of floating-point arithmetic.
11+
//
12+
// On the MLIR side, floating-point values must be bitcast to 64-bit integers
13+
// before calling a runtime function. If a floating-point type has less than
14+
// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
15+
// integer.
16+
//
17+
// Runtime functions receive the floating-point operands of the arithmetic
18+
// operation in the form of 64-bit integers, along with the APFloat semantics
19+
// in the form of a 32-bit integer, which will be interpreted as an
20+
// APFloatBase::Semantics enum value.
21+
//
922
#include "llvm/ADT/APFloat.h"
1023

11-
#include <iostream>
12-
1324
#if (defined(_WIN32) || defined(__CYGWIN__))
1425
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
1526
#else
1627
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
1728
#endif
1829

30+
/// Binary operations without rounding mode.
1931
#define APFLOAT_BINARY_OP(OP) \
20-
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
32+
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
2133
int32_t semantics, uint64_t a, uint64_t b) { \
2234
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
2335
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
2436
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
2537
llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
2638
llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
27-
llvm::APFloatBase::opStatus status = lhs.OP(rhs); \
28-
assert(status == llvm::APFloatBase::opOK && "expected " #OP \
29-
" opstatus to be OK"); \
39+
lhs.OP(rhs); \
3040
return lhs.bitcastToAPInt().getZExtValue(); \
3141
}
3242

43+
/// Binary operations with rounding mode.
3344
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
34-
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED APFloat_##OP( \
45+
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
3546
int32_t semantics, uint64_t a, uint64_t b) { \
3647
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
3748
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
3849
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
3950
llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
4051
llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
41-
llvm::APFloatBase::opStatus status = lhs.OP(rhs, ROUNDING_MODE); \
42-
assert(status == llvm::APFloatBase::opOK && "expected " #OP \
43-
" opstatus to be OK"); \
52+
lhs.OP(rhs, ROUNDING_MODE); \
4453
return lhs.bitcastToAPInt().getZExtValue(); \
4554
}
4655

@@ -57,7 +66,6 @@ BIN_OPS_WITH_ROUNDING(APFLOAT_BINARY_OP_ROUNDING_MODE)
5766
#undef APFLOAT_BINARY_OP_ROUNDING_MODE
5867

5968
APFLOAT_BINARY_OP(remainder)
60-
APFLOAT_BINARY_OP(mod)
6169

6270
#undef APFLOAT_BINARY_OP
6371

@@ -68,6 +76,6 @@ void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
6876
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
6977
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
7078
double d = x.convertToDouble();
71-
std::cout << d << std::endl;
79+
fprintf(stdout, "%lg", d);
7280
}
7381
}

0 commit comments

Comments
 (0)