Skip to content

Commit 83bd4fe

Browse files
jacquesguanjacquesguan
authored andcommitted
[mlir][Math] Replace some constant folder functions with common folder functions.
Differential Revision: https://reviews.llvm.org/D123485
1 parent 4aeb2a5 commit 83bd4fe

File tree

1 file changed

+25
-78
lines changed

1 file changed

+25
-78
lines changed

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 25 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10+
#include "mlir/Dialect/CommonFolders.h"
1011
#include "mlir/Dialect/Math/IR/Math.h"
1112
#include "mlir/IR/Builders.h"
1213

@@ -25,119 +26,65 @@ using namespace mlir::math;
2526
//===----------------------------------------------------------------------===//
2627

2728
OpFoldResult math::AbsOp::fold(ArrayRef<Attribute> operands) {
28-
auto constOperand = operands.front();
29-
if (!constOperand)
30-
return {};
31-
32-
auto attr = constOperand.dyn_cast<FloatAttr>();
33-
if (!attr)
34-
return {};
35-
36-
auto ft = getType().cast<FloatType>();
37-
38-
APFloat apf = attr.getValue();
39-
40-
if (ft.getWidth() == 64)
41-
return FloatAttr::get(getType(), fabs(apf.convertToDouble()));
42-
43-
if (ft.getWidth() == 32)
44-
return FloatAttr::get(getType(), fabsf(apf.convertToFloat()));
45-
46-
return {};
29+
return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
30+
APFloat result(a);
31+
return abs(result);
32+
});
4733
}
4834

4935
//===----------------------------------------------------------------------===//
5036
// CeilOp folder
5137
//===----------------------------------------------------------------------===//
5238

5339
OpFoldResult math::CeilOp::fold(ArrayRef<Attribute> operands) {
54-
auto constOperand = operands.front();
55-
if (!constOperand)
56-
return {};
57-
58-
auto attr = constOperand.dyn_cast<FloatAttr>();
59-
if (!attr)
60-
return {};
61-
62-
APFloat sourceVal = attr.getValue();
63-
sourceVal.roundToIntegral(llvm::RoundingMode::TowardPositive);
64-
65-
return FloatAttr::get(getType(), sourceVal);
40+
return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
41+
APFloat result(a);
42+
result.roundToIntegral(llvm::RoundingMode::TowardPositive);
43+
return result;
44+
});
6645
}
6746

6847
//===----------------------------------------------------------------------===//
6948
// CopySignOp folder
7049
//===----------------------------------------------------------------------===//
7150

7251
OpFoldResult math::CopySignOp::fold(ArrayRef<Attribute> operands) {
73-
auto ft = getType().dyn_cast<FloatType>();
74-
if (!ft)
75-
return {};
76-
77-
APFloat vals[2]{APFloat(ft.getFloatSemantics()),
78-
APFloat(ft.getFloatSemantics())};
79-
for (int i = 0; i < 2; ++i) {
80-
if (!operands[i])
81-
return {};
82-
83-
auto attr = operands[i].dyn_cast<FloatAttr>();
84-
if (!attr)
85-
return {};
86-
87-
vals[i] = attr.getValue();
88-
}
89-
90-
vals[0].copySign(vals[1]);
91-
92-
return FloatAttr::get(getType(), vals[0]);
52+
return constFoldBinaryOp<FloatAttr>(operands,
53+
[](const APFloat &a, const APFloat &b) {
54+
APFloat result(a);
55+
result.copySign(b);
56+
return result;
57+
});
9358
}
9459

9560
//===----------------------------------------------------------------------===//
9661
// CountLeadingZerosOp folder
9762
//===----------------------------------------------------------------------===//
9863

9964
OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef<Attribute> operands) {
100-
auto constOperand = operands.front();
101-
if (!constOperand)
102-
return {};
103-
104-
auto attr = constOperand.dyn_cast<IntegerAttr>();
105-
if (!attr)
106-
return {};
107-
108-
return IntegerAttr::get(getType(), attr.getValue().countLeadingZeros());
65+
return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
66+
return APInt(a.getBitWidth(), a.countLeadingZeros());
67+
});
10968
}
11069

11170
//===----------------------------------------------------------------------===//
11271
// CountTrailingZerosOp folder
11372
//===----------------------------------------------------------------------===//
11473

11574
OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef<Attribute> operands) {
116-
auto constOperand = operands.front();
117-
if (!constOperand)
118-
return {};
119-
120-
auto attr = constOperand.dyn_cast<IntegerAttr>();
121-
if (!attr)
122-
return {};
123-
124-
return IntegerAttr::get(getType(), attr.getValue().countTrailingZeros());
75+
return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
76+
return APInt(a.getBitWidth(), a.countTrailingZeros());
77+
});
12578
}
12679

12780
//===----------------------------------------------------------------------===//
12881
// CtPopOp folder
12982
//===----------------------------------------------------------------------===//
13083

13184
OpFoldResult math::CtPopOp::fold(ArrayRef<Attribute> operands) {
132-
auto constOperand = operands.front();
133-
if (!constOperand)
134-
return {};
135-
136-
auto attr = constOperand.dyn_cast<IntegerAttr>();
137-
if (!attr)
138-
return {};
139-
140-
return IntegerAttr::get(getType(), attr.getValue().countPopulation());
85+
return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
86+
return APInt(a.getBitWidth(), a.countPopulation());
87+
});
14188
}
14289

14390
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)