Skip to content

Commit de9b2af

Browse files
walk instead of dialect conversion
1 parent 78df4a8 commit de9b2af

File tree

6 files changed

+138
-148
lines changed

6 files changed

+138
-148
lines changed

mlir/include/mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,10 @@
1212
#include <memory>
1313

1414
namespace mlir {
15-
16-
class DialectRegistry;
17-
class RewritePatternSet;
1815
class Pass;
1916

2017
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
2118
#include "mlir/Conversion/Passes.h.inc"
22-
23-
namespace arith {
24-
void populateArithToAPFloatConversionPatterns(RewritePatternSet &patterns);
25-
} // namespace arith
2619
} // namespace mlir
2720

28-
#endif // MLIR_CONVERSION_ARITHTOAPFloat_ARITHTOAPFloat_H
21+
#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,13 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
190190
// ArithToAPFloat
191191
//===----------------------------------------------------------------------===//
192192

193-
def ArithToAPFloatConversionPass : Pass<"convert-arith-to-apfloat"> {
194-
let summary = "Convert Arith dialect ops on FP8 types to APFloat lib calls";
193+
def ArithToAPFloatConversionPass
194+
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
195+
let summary = "Convert Arith ops to APFloat runtime library calls";
195196
let description = [{
196-
This pass converts supported Arith ops which manipulate FP8 typed values to APFloat lib calls.
197+
This pass converts supported Arith ops to APFloat-based runtime library
198+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
199+
floating-point arithmetic operations.
197200
}];
198201
let dependentDialects = ["func::FuncDialect"];
199202
let options = [];

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 75 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/Func/Utils/Utils.h"
1515
#include "mlir/IR/Verifier.h"
16-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
17+
#include "llvm/ADT/TypeSwitch.h"
1718

1819
namespace mlir {
1920
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -23,100 +24,55 @@ namespace mlir {
2324
using namespace mlir;
2425
using namespace mlir::func;
2526

26-
#define APFLOAT_BIN_OPS(X) \
27-
X(add) \
28-
X(subtract) \
29-
X(multiply) \
30-
X(divide) \
31-
X(remainder) \
32-
X(mod)
33-
34-
#define APFLOAT_EXTERN_K(OP) kApFloat_##OP
35-
36-
#define APFLOAT_EXTERN_NAME(OP) \
37-
static constexpr llvm::StringRef APFLOAT_EXTERN_K(OP) = "_mlir_" \
38-
"apfloat_" #OP;
39-
40-
namespace mlir::func {
41-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DECL(OP) \
42-
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
43-
OpBuilder &b, Operation *moduleOp, \
44-
SymbolTableCollection *symbolTables = nullptr);
45-
46-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DECL)
47-
48-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DECL
49-
50-
APFLOAT_BIN_OPS(APFLOAT_EXTERN_NAME)
51-
52-
#define LOOKUP_OR_CREATE_APFLOAT_FN_DEFN(OP) \
53-
FailureOr<FuncOp> lookupOrCreateApFloat##OP##Fn( \
54-
OpBuilder &b, Operation *moduleOp, \
55-
SymbolTableCollection *symbolTables) { \
56-
return lookupOrCreateFn(b, moduleOp, APFLOAT_EXTERN_K(OP), \
57-
{IntegerType::get(moduleOp->getContext(), 32), \
58-
IntegerType::get(moduleOp->getContext(), 64), \
59-
IntegerType::get(moduleOp->getContext(), 64)}, \
60-
{IntegerType::get(moduleOp->getContext(), 64)}, \
61-
/*setPrivate*/ true, symbolTables); \
62-
}
63-
64-
APFLOAT_BIN_OPS(LOOKUP_OR_CREATE_APFLOAT_FN_DEFN)
65-
#undef LOOKUP_OR_CREATE_APFLOAT_FN_DEFN
66-
} // namespace mlir::func
67-
68-
struct FancyAddFLowering : OpRewritePattern<arith::AddFOp> {
69-
using OpRewritePattern::OpRewritePattern;
70-
71-
LogicalResult matchAndRewrite(arith::AddFOp op,
72-
PatternRewriter &rewriter) const override {
73-
// Get APFloat adder function from runtime library.
74-
auto parent = op->getParentOfType<ModuleOp>();
75-
if (!parent)
76-
return failure();
77-
if (!llvm::isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
78-
Float8E5M2FNUZType, Float8E4M3FNUZType,
79-
Float8E4M3B11FNUZType, Float8E3M4Type, Float4E2M1FNType,
80-
Float6E2M3FNType, Float6E3M2FNType, Float8E8M0FNUType>(
81-
op.getType()))
82-
return failure();
83-
FailureOr<Operation *> adder = lookupOrCreateApFloataddFn(rewriter, parent);
84-
85-
// Cast operands to 64-bit integers.
86-
Location loc = op.getLoc();
87-
auto floatTy = cast<FloatType>(op.getType());
88-
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
89-
auto int64Type = rewriter.getI64Type();
90-
Value lhsBits = arith::ExtUIOp::create(
91-
rewriter, loc, int64Type,
92-
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
93-
Value rhsBits = arith::ExtUIOp::create(
94-
rewriter, loc, int64Type,
95-
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
96-
97-
// Call software implementation of floating point addition.
98-
int32_t sem =
99-
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
100-
Value semValue = arith::ConstantOp::create(
101-
rewriter, loc, rewriter.getI32Type(),
102-
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
103-
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
104-
auto resultOp =
105-
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
106-
SymbolRefAttr::get(*adder), params);
107-
108-
// Truncate result to the original width.
109-
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
110-
resultOp->getResult(0));
111-
rewriter.replaceAllUsesWith(
112-
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
113-
return success();
114-
}
115-
};
27+
/// Returns the declaration of the APFloat runtime function named
/// `_mlir_apfloat_<name>` in `moduleOp`, as produced by `lookupOrCreateFn`.
///
/// The function is requested with the signature `(i32, i64, i64) -> i64` and
/// with `setPrivate` enabled. `symbolTables` is forwarded unchanged to
/// `lookupOrCreateFn`.
static FailureOr<Operation *>
lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                       SymbolTableCollection *symbolTables = nullptr) {
  MLIRContext *ctx = moduleOp->getContext();
  Type i32Ty = IntegerType::get(ctx, 32);
  Type i64Ty = IntegerType::get(ctx, 64);
  // Every runtime entry point shares the `_mlir_apfloat_` symbol prefix.
  const std::string symbolName = (llvm::Twine("_mlir_apfloat_") + name).str();
  return lookupOrCreateFn(b, moduleOp, symbolName, {i32Ty, i64Ty, i64Ty},
                          {i64Ty}, /*setPrivate*/ true, symbolTables);
}
11638

117-
void arith::populateArithToAPFloatConversionPatterns(
118-
RewritePatternSet &patterns) {
119-
patterns.add<FancyAddFLowering>(patterns.getContext());
39+
template <typename OpTy>
40+
static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, ModuleOp module,
41+
OpTy op, StringRef apfloatName) {
42+
// Get APFloat function from runtime library.
43+
FailureOr<Operation *> fn =
44+
lookupOrCreateBinaryFn(rewriter, module, apfloatName);
45+
if (failed(fn))
46+
return op->emitError("failed to lookup or create APFloat function");
47+
48+
// Cast operands to 64-bit integers.
49+
Location loc = op.getLoc();
50+
auto floatTy = cast<FloatType>(op.getType());
51+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
52+
auto int64Type = rewriter.getI64Type();
53+
Value lhsBits = arith::ExtUIOp::create(
54+
rewriter, loc, int64Type,
55+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
56+
Value rhsBits = arith::ExtUIOp::create(
57+
rewriter, loc, int64Type,
58+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
59+
60+
// Call APFloat function.
61+
int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
62+
Value semValue = arith::ConstantOp::create(
63+
rewriter, loc, rewriter.getI32Type(),
64+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
65+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
66+
auto resultOp =
67+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
68+
SymbolRefAttr::get(*fn), params);
69+
70+
// Truncate result to the original width.
71+
Value truncatedBits =
72+
arith::TruncIOp::create(rewriter, loc, intWType, resultOp->getResult(0));
73+
rewriter.replaceOp(
74+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
75+
return success();
12076
}
12177

12278
namespace {
@@ -126,10 +82,31 @@ struct ArithToAPFloatConversionPass final
12682
ArithToAPFloatConversionPass>::ArithToAPFloatConversionPassBase;
12783

12884
void runOnOperation() override {
129-
Operation *op = getOperation();
130-
RewritePatternSet patterns(op->getContext());
131-
arith::populateArithToAPFloatConversionPatterns(patterns);
132-
if (failed(applyPatternsGreedily(op, std::move(patterns))))
85+
ModuleOp module = getOperation();
86+
IRRewriter rewriter(getOperation()->getContext());
87+
SmallVector<arith::AddFOp> addOps;
88+
WalkResult status = module->walk([&](Operation *op) {
89+
rewriter.setInsertionPoint(op);
90+
LogicalResult result =
91+
llvm::TypeSwitch<Operation *, LogicalResult>(op)
92+
.Case<arith::AddFOp>([&](arith::AddFOp op) {
93+
return rewriteBinaryOp(rewriter, module, op, "add");
94+
})
95+
.Case<arith::SubFOp>([&](arith::SubFOp op) {
96+
return rewriteBinaryOp(rewriter, module, op, "subtract");
97+
})
98+
.Case<arith::MulFOp>([&](arith::MulFOp op) {
99+
return rewriteBinaryOp(rewriter, module, op, "mulitply");
100+
})
101+
.Case<arith::DivFOp>([&](arith::DivFOp op) {
102+
return rewriteBinaryOp(rewriter, module, op, "divide");
103+
})
104+
.Default([](Operation *op) { return success(); });
105+
if (failed(result))
106+
return WalkResult::interrupt();
107+
return WalkResult::advance();
108+
});
109+
if (status.wasInterrupted())
133110
return signalPassFailure();
134111
}
135112
};

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- ArmRunnerUtils.cpp - Utilities for configuring architecture properties //
1+
//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,31 +16,29 @@
1616
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
1717
#endif
1818

19+
/// Binary operations without rounding mode.
///
/// Expands to the definition of `_mlir_apfloat_<OP>(semantics, a, b)`.
/// `semantics` is an `llvm::APFloatBase::Semantics` enum value selecting the
/// float format; `a` and `b` hold the raw bit patterns of the two operands.
/// The function applies `llvm::APFloat::OP` and returns the raw bit pattern of
/// the result. The `opStatus` returned by the operation is deliberately
/// discarded.
#define APFLOAT_BINARY_OP(OP) \
  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
      int32_t semantics, uint64_t a, uint64_t b) { \
    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
        static_cast<llvm::APFloatBase::Semantics>(semantics)); \
    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
    (void)lhs.OP(rhs); \
    return lhs.bitcastToAPInt().getZExtValue(); \
  }
3231

32+
/// Binary operations with rounding mode.
///
/// Expands to the definition of `_mlir_apfloat_<OP>(semantics, a, b)`.
/// `semantics` is an `llvm::APFloatBase::Semantics` enum value selecting the
/// float format; `a` and `b` hold the raw bit patterns of the two operands.
/// The function applies `llvm::APFloat::OP` with the fixed `ROUNDING_MODE` and
/// returns the raw bit pattern of the result. The `opStatus` returned by the
/// operation is deliberately discarded.
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
      int32_t semantics, uint64_t a, uint64_t b) { \
    const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
        static_cast<llvm::APFloatBase::Semantics>(semantics)); \
    unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); \
    llvm::APFloat lhs(sem, llvm::APInt(bitWidth, a)); \
    llvm::APFloat rhs(sem, llvm::APInt(bitWidth, b)); \
    (void)lhs.OP(rhs, ROUNDING_MODE); \
    return lhs.bitcastToAPInt().getZExtValue(); \
  }
4644

@@ -68,6 +66,6 @@ void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
6866
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
6967
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
7068
double d = x.convertToDouble();
71-
std::cout << d << std::endl;
69+
fprintf(stdout, "%lg", d);
7270
}
7371
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt %s --convert-arith-to-apfloat | FileCheck %s
2+
3+
// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
4+
5+
// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
6+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
7+
// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
8+
// CHECK: }
9+
10+
// CHECK-LABEL: func.func @entry() {
11+
// CHECK: %[[cst:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
12+
// CHECK: %[[rhs:.*]] = call @foo() : () -> f8E4M3FN
13+
// CHECK: %[[lhs_casted:.*]] = arith.bitcast %[[cst]] : f8E4M3FN to i8
14+
// CHECK: %[[lhs_ext:.*]] = arith.extui %[[lhs_casted]] : i8 to i64
15+
// CHECK: %[[rhs_casted:.*]] = arith.bitcast %[[rhs]] : f8E4M3FN to i8
16+
// CHECK: %[[rhs_ext:.*]] = arith.extui %[[rhs_casted]] : i8 to i64
17+
// CHECK: %[[c10_i32:.*]] = arith.constant 10 : i32
18+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_add(%[[c10_i32]], %[[lhs_ext]], %[[rhs_ext]]) : (i32, i64, i64) -> i64
19+
// CHECK: %[[res_trunc:.*]] = arith.trunci %[[res]] : i64 to i8
20+
// CHECK: %[[res_casted:.*]] = arith.bitcast %[[res_trunc]] : i8 to f8E4M3FN
21+
// CHECK: vector.print %[[res_casted]] : f8E4M3FN
22+
// CHECK: return
23+
// CHECK: }
24+
25+
// Put rhs into separate function so that it won't be constant-folded.
26+
func.func @foo() -> f8E4M3FN {
27+
%cst = arith.constant 2.2 : f8E4M3FN
28+
return %cst : f8E4M3FN
29+
}
30+
31+
func.func @entry() {
32+
%a = arith.constant 1.4 : f8E4M3FN
33+
%b = func.call @foo() : () -> (f8E4M3FN)
34+
%c = arith.addf %a, %b : f8E4M3FN
35+
36+
vector.print %c : f8E4M3FN
37+
return
38+
}
Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,19 @@
1-
// Check that the ceildivsi lowering is correct.
2-
// We do not check any poison or UB values, as it is not possible to catch them.
3-
4-
// RUN: mlir-opt %s --convert-arith-to-apfloat
1+
// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-to-llvm | \
2+
// RUN: mlir-runner -e entry --entry-point-result=void \
3+
// RUN: --shared-libs=%mlir_c_runner_utils | FileCheck %s
54

65
// Put rhs into separate function so that it won't be constant-folded.
7-
func.func @foo() -> f4E2M1FN {
8-
%cst = arith.constant 5.0 : f4E2M1FN
9-
return %cst : f4E2M1FN
6+
func.func @foo() -> f8E4M3FN {
7+
%cst = arith.constant 2.2 : f8E4M3FN
8+
return %cst : f8E4M3FN
109
}
1110

1211
func.func @entry() {
13-
%a = arith.constant 5.0 : f4E2M1FN
14-
%b = func.call @foo() : () -> (f4E2M1FN)
15-
%c = arith.addf %a, %b : f4E2M1FN
16-
vector.print %c : f4E2M1FN
12+
%a = arith.constant 1.4 : f8E4M3FN
13+
%b = func.call @foo() : () -> (f8E4M3FN)
14+
%c = arith.addf %a, %b : f8E4M3FN
15+
16+
// CHECK: 3.5
17+
vector.print %c : f8E4M3FN
1718
return
1819
}
19-
20-
// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
21-
22-
// CHECK-LABEL: func.func @foo() -> f4E2M1FN {
23-
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4.000000e+00 : f4E2M1FN
24-
// CHECK: return %[[CONSTANT_0]] : f4E2M1FN
25-
// CHECK: }
26-
27-
// CHECK-LABEL: func.func @entry() {
28-
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 18 : i32
29-
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 6 : i64
30-
// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f4E2M1FN
31-
// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[VAL_0]] : f4E2M1FN to i4
32-
// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i4 to i64
33-
// CHECK: %[[VAL_1:.*]] = call @_mlir_apfloat_add(%[[CONSTANT_0]], %[[EXTUI_0]], %[[CONSTANT_1]]) : (i32, i64, i64) -> i64
34-
// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i4
35-
// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[TRUNCI_0]] : i4 to f4E2M1FN
36-
// CHECK: vector.print %[[BITCAST_1]] : f4E2M1FN
37-
// CHECK: return
38-
// CHECK: }

0 commit comments

Comments
 (0)