Skip to content

Commit c79a9ef

Browse files
committed
make everyone happy
1 parent 7d29b71 commit c79a9ef

File tree

6 files changed

+148
-87
lines changed

6 files changed

+148
-87
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,6 @@ def ArithToAPFloatConversionPass
199199
floating-point arithmetic operations.
200200
}];
201201
let dependentDialects = ["func::FuncDialect"];
202-
let options = [];
203202
}
204203

205204
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,12 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
6060
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
6161
mlir::ModuleOp moduleOp);
6262

63-
/// Create a FuncOp with signature `resultTypes`(`paramTypes`)` and name `name`.
64-
/// Return a failure if the FuncOp found has unexpected signature.
65-
FailureOr<FuncOp>
66-
lookupOrCreateFnDecl(OpBuilder &b, Operation *moduleOp, StringRef name,
67-
ArrayRef<Type> paramTypes = {},
68-
ArrayRef<Type> resultTypes = {}, bool setPrivate = false,
69-
SymbolTableCollection *symbolTables = nullptr);
63+
/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name
64+
/// `name`. Return a failure if the FuncOp is found but with a different
65+
/// signature.
66+
FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
67+
FunctionType funcT,
68+
SymbolTableCollection *symbolTables = nullptr);
7069

7170
} // namespace func
7271
} // namespace mlir

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include "mlir/IR/PatternMatch.h"
1616
#include "mlir/IR/Verifier.h"
1717
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
18+
#include "llvm/Support/Debug.h"
19+
20+
#define DEBUG_TYPE "arith-to-apfloat"
1821

1922
namespace mlir {
2023
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
@@ -24,6 +27,22 @@ namespace mlir {
2427
using namespace mlir;
2528
using namespace mlir::func;
2629

30+
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
31+
StringRef name, FunctionType funcT, bool setPrivate,
32+
SymbolTableCollection *symbolTables = nullptr) {
33+
OpBuilder::InsertionGuard g(b);
34+
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
35+
b.setInsertionPointToStart(&symTable->getRegion(0).front());
36+
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
37+
if (setPrivate)
38+
funcOp.setPrivate();
39+
if (symbolTables) {
40+
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
41+
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
42+
}
43+
return funcOp;
44+
}
45+
2746
/// Helper function to look up or create the symbol for a runtime library
2847
/// function for a binary arithmetic operation.
2948
///
@@ -34,34 +53,42 @@ using namespace mlir::func;
3453
/// This function will return a failure if the function is found but has an
3554
/// unexpected signature.
3655
///
37-
static FailureOr<Operation *>
38-
lookupOrCreateBinaryFn(OpBuilder &b, Operation *moduleOp, StringRef name,
56+
static FailureOr<FuncOp>
57+
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
3958
SymbolTableCollection *symbolTables = nullptr) {
40-
auto i32Type = IntegerType::get(moduleOp->getContext(), 32);
41-
auto i64Type = IntegerType::get(moduleOp->getContext(), 64);
42-
return lookupOrCreateFnDecl(b, moduleOp,
43-
(llvm::Twine("_mlir_apfloat_") + name).str(),
44-
{i32Type, i64Type, i64Type}, {i64Type},
45-
/*setPrivate=*/true, symbolTables);
59+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
60+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
61+
62+
std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str();
63+
FunctionType funcT =
64+
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
65+
FailureOr<FuncOp> func =
66+
lookupFnDecl(symTable, funcName, funcT, symbolTables);
67+
// Failed due to type mismatch.
68+
if (failed(func))
69+
return func;
70+
// Successfully matched existing decl.
71+
if (*func)
72+
return *func;
73+
74+
return createFnDecl(b, symTable, funcName, funcT,
75+
/*setPrivate=*/true, symbolTables);
4676
}
4777

4878
/// Rewrite a binary arithmetic operation to an APFloat function call.
4979
template <typename OpTy, const char *APFloatName>
50-
struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
51-
using OpRewritePattern<OpTy>::OpRewritePattern;
80+
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
81+
BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
82+
SymbolOpInterface symTable)
83+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};
5284

5385
LogicalResult matchAndRewrite(OpTy op,
5486
PatternRewriter &rewriter) const override {
55-
auto moduleOp = op->template getParentOfType<ModuleOp>();
56-
if (!moduleOp) {
57-
return rewriter.notifyMatchFailure(
58-
op, "arith op must be contained within a builtin.module");
59-
}
6087
// Get APFloat function from runtime library.
61-
FailureOr<Operation *> fn =
62-
lookupOrCreateBinaryFn(rewriter, moduleOp, APFloatName);
88+
FailureOr<FuncOp> fn =
89+
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
6390
if (failed(fn))
64-
return op->emitError("failed to lookup or create APFloat function");
91+
return fn;
6592

6693
rewriter.setInsertionPoint(op);
6794
// Cast operands to 64-bit integers.
@@ -94,6 +121,8 @@ struct ArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
94121
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
95122
return success();
96123
}
124+
125+
SymbolOpInterface symTable;
97126
};
98127

99128
namespace {
@@ -109,12 +138,24 @@ struct ArithToAPFloatConversionPass final
109138
static const char multiply[] = "multiply";
110139
static const char divide[] = "divide";
111140
static const char remainder[] = "remainder";
112-
patterns.add<ArithOpToAPFloatConversion<arith::AddFOp, add>,
113-
ArithOpToAPFloatConversion<arith::SubFOp, subtract>,
114-
ArithOpToAPFloatConversion<arith::MulFOp, multiply>,
115-
ArithOpToAPFloatConversion<arith::DivFOp, divide>,
116-
ArithOpToAPFloatConversion<arith::RemFOp, remainder>>(context);
141+
patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
142+
BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
143+
BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
144+
BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
145+
BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
146+
context, 1, getOperation());
147+
LogicalResult result = success();
148+
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
149+
if (diag.getSeverity() == DiagnosticSeverity::Error) {
150+
result = failure();
151+
}
152+
// NB: if you don't return failure, no other diag handlers will fire (see
153+
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
154+
return failure();
155+
});
117156
walkAndApplyPatterns(getOperation(), std::move(patterns));
157+
if (failed(result))
158+
return signalPassFailure();
118159
}
119160
};
120161
} // namespace

mlir/lib/Dialect/Func/Utils/Utils.cpp

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -256,44 +256,26 @@ func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
256256
}
257257

258258
FailureOr<func::FuncOp>
259-
func::lookupOrCreateFnDecl(OpBuilder &b, Operation *moduleOp, StringRef name,
260-
ArrayRef<Type> paramTypes,
261-
ArrayRef<Type> resultTypes, bool setPrivate,
262-
SymbolTableCollection *symbolTables) {
263-
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
264-
"expected SymbolTable operation");
265-
259+
func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
260+
FunctionType funcT, SymbolTableCollection *symbolTables) {
266261
FuncOp func;
267262
if (symbolTables) {
268263
func = symbolTables->lookupSymbolIn<FuncOp>(
269-
moduleOp, StringAttr::get(moduleOp->getContext(), name));
264+
symTable, StringAttr::get(symTable->getContext(), name));
270265
} else {
271266
func = llvm::dyn_cast_or_null<FuncOp>(
272-
SymbolTable::lookupSymbolIn(moduleOp, name));
267+
SymbolTable::lookupSymbolIn(symTable, name));
273268
}
274269

275-
FunctionType funcT =
276-
FunctionType::get(b.getContext(), paramTypes, resultTypes);
277-
// Assert the signature of the found function is same as expected
278-
if (func) {
279-
if (funcT != func.getFunctionType()) {
280-
func.emitError("redefinition of function '")
281-
<< name << "' of different type " << funcT << " is prohibited";
282-
return failure();
283-
}
270+
if (!func)
284271
return func;
285-
}
286272

287-
OpBuilder::InsertionGuard g(b);
288-
assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
289-
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
290-
FuncOp funcOp = FuncOp::create(b, moduleOp->getLoc(), name, funcT);
291-
if (setPrivate)
292-
funcOp.setPrivate();
293-
if (symbolTables) {
294-
SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
295-
symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
273+
mlir::FunctionType foundFuncT = func.getFunctionType();
274+
// Assert the signature of the found function is same as expected
275+
if (funcT != foundFuncT) {
276+
return func.emitError("matched function '")
277+
<< name << "' but with different type: " << foundFuncT
278+
<< " (expected " << funcT << ")";
296279
}
297-
298-
return funcOp;
280+
return func;
299281
}

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
/// Binary operations without rounding mode.
3131
#define APFLOAT_BINARY_OP(OP) \
32-
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
32+
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \
3333
int32_t semantics, uint64_t a, uint64_t b) { \
3434
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
3535
static_cast<llvm::APFloatBase::Semantics>(semantics)); \
@@ -42,7 +42,7 @@
4242

4343
/// Binary operations with rounding mode.
4444
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \
45-
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED _mlir_apfloat_##OP( \
45+
int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP( \
4646
int32_t semantics, uint64_t a, uint64_t b) { \
4747
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \
4848
static_cast<llvm::APFloatBase::Semantics>(semantics)); \

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,45 @@
1-
// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --convert-arith-to-apfloat -split-input-file -verify-diagnostics | FileCheck %s
22

3-
// CHECK-LABEL: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
3+
// CHECK-LABEL: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64
44

55
// CHECK-LABEL: func.func @foo() -> f8E4M3FN {
66
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 2.250000e+00 : f8E4M3FN
77
// CHECK: return %[[CONSTANT_0]] : f8E4M3FN
88
// CHECK: }
99

10+
// CHECK-LABEL: func.func @bar() -> f6E3M2FN {
11+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 3.000000e+00 : f6E3M2FN
12+
// CHECK: return %[[CONSTANT_0]] : f6E3M2FN
13+
// CHECK: }
14+
15+
// Illustrate that both f8E4M3FN and f6E3M2FN calling the same __mlir_apfloat_add is fine
16+
// because each gets its own semantics enum and gets bitcast/extui/trunci to its own width.
1017
// CHECK-LABEL: func.func @full_example() {
11-
// CHECK: %[[cst:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
12-
// CHECK: %[[rhs:.*]] = call @foo() : () -> f8E4M3FN
13-
// CHECK: %[[lhs_casted:.*]] = arith.bitcast %[[cst]] : f8E4M3FN to i8
14-
// CHECK: %[[lhs_ext:.*]] = arith.extui %[[lhs_casted]] : i8 to i64
15-
// CHECK: %[[rhs_casted:.*]] = arith.bitcast %[[rhs]] : f8E4M3FN to i8
16-
// CHECK: %[[rhs_ext:.*]] = arith.extui %[[rhs_casted]] : i8 to i64
17-
// CHECK: %[[c10_i32:.*]] = arith.constant 10 : i32
18-
// CHECK: %[[res:.*]] = call @_mlir_apfloat_add(%[[c10_i32]], %[[lhs_ext]], %[[rhs_ext]]) : (i32, i64, i64) -> i64
19-
// CHECK: %[[res_trunc:.*]] = arith.trunci %[[res]] : i64 to i8
20-
// CHECK: %[[res_casted:.*]] = arith.bitcast %[[res_trunc]] : i8 to f8E4M3FN
21-
// CHECK: vector.print %[[res_casted]] : f8E4M3FN
18+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1.375000e+00 : f8E4M3FN
19+
// CHECK: %[[VAL_0:.*]] = call @foo() : () -> f8E4M3FN
20+
// CHECK: %[[BITCAST_0:.*]] = arith.bitcast %[[CONSTANT_0]] : f8E4M3FN to i8
21+
// CHECK: %[[EXTUI_0:.*]] = arith.extui %[[BITCAST_0]] : i8 to i64
22+
// CHECK: %[[BITCAST_1:.*]] = arith.bitcast %[[VAL_0]] : f8E4M3FN to i8
23+
// CHECK: %[[EXTUI_1:.*]] = arith.extui %[[BITCAST_1]] : i8 to i64
24+
// // fltSemantics semantics for f8E4M3FN
25+
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 10 : i32
26+
// CHECK: %[[VAL_1:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_1]], %[[EXTUI_0]], %[[EXTUI_1]]) : (i32, i64, i64) -> i64
27+
// CHECK: %[[TRUNCI_0:.*]] = arith.trunci %[[VAL_1]] : i64 to i8
28+
// CHECK: %[[BITCAST_2:.*]] = arith.bitcast %[[TRUNCI_0]] : i8 to f8E4M3FN
29+
// CHECK: vector.print %[[BITCAST_2]] : f8E4M3FN
30+
31+
// CHECK: %[[CONSTANT_2:.*]] = arith.constant 2.500000e+00 : f6E3M2FN
32+
// CHECK: %[[VAL_2:.*]] = call @bar() : () -> f6E3M2FN
33+
// CHECK: %[[BITCAST_3:.*]] = arith.bitcast %[[CONSTANT_2]] : f6E3M2FN to i6
34+
// CHECK: %[[EXTUI_2:.*]] = arith.extui %[[BITCAST_3]] : i6 to i64
35+
// CHECK: %[[BITCAST_4:.*]] = arith.bitcast %[[VAL_2]] : f6E3M2FN to i6
36+
// CHECK: %[[EXTUI_3:.*]] = arith.extui %[[BITCAST_4]] : i6 to i64
37+
// // fltSemantics semantics for f6E3M2FN
38+
// CHECK: %[[CONSTANT_3:.*]] = arith.constant 16 : i32
39+
// CHECK: %[[VAL_3:.*]] = call @__mlir_apfloat_add(%[[CONSTANT_3]], %[[EXTUI_2]], %[[EXTUI_3]]) : (i32, i64, i64) -> i64
40+
// CHECK: %[[TRUNCI_1:.*]] = arith.trunci %[[VAL_3]] : i64 to i6
41+
// CHECK: %[[BITCAST_5:.*]] = arith.bitcast %[[TRUNCI_1]] : i6 to f6E3M2FN
42+
// CHECK: vector.print %[[BITCAST_5]] : f6E3M2FN
2243
// CHECK: return
2344
// CHECK: }
2445

@@ -28,60 +49,79 @@ func.func @foo() -> f8E4M3FN {
2849
return %cst : f8E4M3FN
2950
}
3051

52+
func.func @bar() -> f6E3M2FN {
53+
%cst = arith.constant 3.2 : f6E3M2FN
54+
return %cst : f6E3M2FN
55+
}
56+
3157
func.func @full_example() {
3258
%a = arith.constant 1.4 : f8E4M3FN
3359
%b = func.call @foo() : () -> (f8E4M3FN)
3460
%c = arith.addf %a, %b : f8E4M3FN
35-
3661
vector.print %c : f8E4M3FN
62+
63+
%d = arith.constant 2.4 : f6E3M2FN
64+
%e = func.call @bar() : () -> (f6E3M2FN)
65+
%f = arith.addf %d, %e : f6E3M2FN
66+
vector.print %f : f6E3M2FN
3767
return
3868
}
3969

4070
// -----
4171

42-
// CHECK: func.func private @_mlir_apfloat_add(i32, i64, i64) -> i64
72+
// CHECK: func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64
4373
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
44-
// CHECK: call @_mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
74+
// CHECK: call @__mlir_apfloat_add(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
75+
func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
76+
%0 = arith.addf %arg0, %arg1 : f4E2M1FN
77+
return
78+
}
79+
80+
// -----
81+
82+
// Test decl collision (different type)
83+
// expected-error@+1{{matched function '__mlir_apfloat_add' but with different type: '(i32, i32, f32) -> index' (expected '(i32, i64, i64) -> i64')}}
84+
func.func private @__mlir_apfloat_add(i32, i32, f32) -> index
4585
func.func @addf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
4686
%0 = arith.addf %arg0, %arg1 : f4E2M1FN
4787
return
4888
}
4989

5090
// -----
5191

52-
// CHECK: func.func private @_mlir_apfloat_subtract(i32, i64, i64) -> i64
92+
// CHECK: func.func private @__mlir_apfloat_subtract(i32, i64, i64) -> i64
5393
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
54-
// CHECK: call @_mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
94+
// CHECK: call @__mlir_apfloat_subtract(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
5595
func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
5696
%0 = arith.subf %arg0, %arg1 : f4E2M1FN
5797
return
5898
}
5999

60100
// -----
61101

62-
// CHECK: func.func private @_mlir_apfloat_multiply(i32, i64, i64) -> i64
102+
// CHECK: func.func private @__mlir_apfloat_multiply(i32, i64, i64) -> i64
63103
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
64-
// CHECK: call @_mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
104+
// CHECK: call @__mlir_apfloat_multiply(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
65105
func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
66106
%0 = arith.mulf %arg0, %arg1 : f4E2M1FN
67107
return
68108
}
69109

70110
// -----
71111

72-
// CHECK: func.func private @_mlir_apfloat_divide(i32, i64, i64) -> i64
112+
// CHECK: func.func private @__mlir_apfloat_divide(i32, i64, i64) -> i64
73113
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
74-
// CHECK: call @_mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
114+
// CHECK: call @__mlir_apfloat_divide(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
75115
func.func @subf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
76116
%0 = arith.divf %arg0, %arg1 : f4E2M1FN
77117
return
78118
}
79119

80120
// -----
81121

82-
// CHECK: func.func private @_mlir_apfloat_remainder(i32, i64, i64) -> i64
122+
// CHECK: func.func private @__mlir_apfloat_remainder(i32, i64, i64) -> i64
83123
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
84-
// CHECK: call @_mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
124+
// CHECK: call @__mlir_apfloat_remainder(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i64
85125
func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
86126
%0 = arith.remf %arg0, %arg1 : f4E2M1FN
87127
return

0 commit comments

Comments
 (0)