7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10
+ #include " mlir/Dialect/CommonFolders.h"
10
11
#include " mlir/Dialect/Math/IR/Math.h"
11
12
#include " mlir/IR/Builders.h"
12
13
@@ -25,119 +26,65 @@ using namespace mlir::math;
25
26
// ===----------------------------------------------------------------------===//
26
27
27
28
OpFoldResult math::AbsOp::fold (ArrayRef<Attribute> operands) {
28
- auto constOperand = operands.front ();
29
- if (!constOperand)
30
- return {};
31
-
32
- auto attr = constOperand.dyn_cast <FloatAttr>();
33
- if (!attr)
34
- return {};
35
-
36
- auto ft = getType ().cast <FloatType>();
37
-
38
- APFloat apf = attr.getValue ();
39
-
40
- if (ft.getWidth () == 64 )
41
- return FloatAttr::get (getType (), fabs (apf.convertToDouble ()));
42
-
43
- if (ft.getWidth () == 32 )
44
- return FloatAttr::get (getType (), fabsf (apf.convertToFloat ()));
45
-
46
- return {};
29
+ return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
30
+ APFloat result (a);
31
+ return abs (result);
32
+ });
47
33
}
48
34
49
35
// ===----------------------------------------------------------------------===//
50
36
// CeilOp folder
51
37
// ===----------------------------------------------------------------------===//
52
38
53
39
OpFoldResult math::CeilOp::fold (ArrayRef<Attribute> operands) {
54
- auto constOperand = operands.front ();
55
- if (!constOperand)
56
- return {};
57
-
58
- auto attr = constOperand.dyn_cast <FloatAttr>();
59
- if (!attr)
60
- return {};
61
-
62
- APFloat sourceVal = attr.getValue ();
63
- sourceVal.roundToIntegral (llvm::RoundingMode::TowardPositive);
64
-
65
- return FloatAttr::get (getType (), sourceVal);
40
+ return constFoldUnaryOp<FloatAttr>(operands, [](const APFloat &a) {
41
+ APFloat result (a);
42
+ result.roundToIntegral (llvm::RoundingMode::TowardPositive);
43
+ return result;
44
+ });
66
45
}
67
46
68
47
// ===----------------------------------------------------------------------===//
69
48
// CopySignOp folder
70
49
// ===----------------------------------------------------------------------===//
71
50
72
51
OpFoldResult math::CopySignOp::fold (ArrayRef<Attribute> operands) {
73
- auto ft = getType ().dyn_cast <FloatType>();
74
- if (!ft)
75
- return {};
76
-
77
- APFloat vals[2 ]{APFloat (ft.getFloatSemantics ()),
78
- APFloat (ft.getFloatSemantics ())};
79
- for (int i = 0 ; i < 2 ; ++i) {
80
- if (!operands[i])
81
- return {};
82
-
83
- auto attr = operands[i].dyn_cast <FloatAttr>();
84
- if (!attr)
85
- return {};
86
-
87
- vals[i] = attr.getValue ();
88
- }
89
-
90
- vals[0 ].copySign (vals[1 ]);
91
-
92
- return FloatAttr::get (getType (), vals[0 ]);
52
+ return constFoldBinaryOp<FloatAttr>(operands,
53
+ [](const APFloat &a, const APFloat &b) {
54
+ APFloat result (a);
55
+ result.copySign (b);
56
+ return result;
57
+ });
93
58
}
94
59
95
60
// ===----------------------------------------------------------------------===//
96
61
// CountLeadingZerosOp folder
97
62
// ===----------------------------------------------------------------------===//
98
63
99
64
OpFoldResult math::CountLeadingZerosOp::fold (ArrayRef<Attribute> operands) {
100
- auto constOperand = operands.front ();
101
- if (!constOperand)
102
- return {};
103
-
104
- auto attr = constOperand.dyn_cast <IntegerAttr>();
105
- if (!attr)
106
- return {};
107
-
108
- return IntegerAttr::get (getType (), attr.getValue ().countLeadingZeros ());
65
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
66
+ return APInt (a.getBitWidth (), a.countLeadingZeros ());
67
+ });
109
68
}
110
69
111
70
// ===----------------------------------------------------------------------===//
112
71
// CountTrailingZerosOp folder
113
72
// ===----------------------------------------------------------------------===//
114
73
115
74
OpFoldResult math::CountTrailingZerosOp::fold (ArrayRef<Attribute> operands) {
116
- auto constOperand = operands.front ();
117
- if (!constOperand)
118
- return {};
119
-
120
- auto attr = constOperand.dyn_cast <IntegerAttr>();
121
- if (!attr)
122
- return {};
123
-
124
- return IntegerAttr::get (getType (), attr.getValue ().countTrailingZeros ());
75
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
76
+ return APInt (a.getBitWidth (), a.countTrailingZeros ());
77
+ });
125
78
}
126
79
127
80
// ===----------------------------------------------------------------------===//
128
81
// CtPopOp folder
129
82
// ===----------------------------------------------------------------------===//
130
83
131
84
OpFoldResult math::CtPopOp::fold (ArrayRef<Attribute> operands) {
132
- auto constOperand = operands.front ();
133
- if (!constOperand)
134
- return {};
135
-
136
- auto attr = constOperand.dyn_cast <IntegerAttr>();
137
- if (!attr)
138
- return {};
139
-
140
- return IntegerAttr::get (getType (), attr.getValue ().countPopulation ());
85
+ return constFoldUnaryOp<IntegerAttr>(operands, [](const APInt &a) {
86
+ return APInt (a.getBitWidth (), a.countPopulation ());
87
+ });
141
88
}
142
89
143
90
// ===----------------------------------------------------------------------===//
0 commit comments