15
15
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
16
16
#define MLIR_DIALECT_COMMONFOLDERS_H
17
17
18
+ #include " mlir/IR/Attributes.h"
19
+ #include " mlir/IR/BuiltinAttributeInterfaces.h"
18
20
#include " mlir/IR/BuiltinAttributes.h"
19
- #include " mlir/IR/BuiltinTypes.h"
21
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
22
+ #include " mlir/IR/Types.h"
20
23
#include " llvm/ADT/ArrayRef.h"
21
24
#include " llvm/ADT/STLExtras.h"
25
+
26
+ #include < cassert>
27
+ #include < cstddef>
22
28
#include < optional>
23
29
24
30
namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
30
36
// / Uses `resultType` for the type of the returned attribute.
31
37
// / Optional PoisonAttr template argument allows to specify 'poison' attribute
32
38
// / which will be directly propagated to result.
33
- template <class AttrElementT ,
39
+ template <class AttrElementT , //
34
40
class ElementValueT = typename AttrElementT::ValueType,
35
41
class PoisonAttr = ub::PoisonAttr,
42
+ class ResultAttrElementT = AttrElementT,
43
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
36
44
class CalculationT = function_ref<
37
- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
45
+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
38
46
Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
39
47
Type resultType,
40
48
CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
65
73
if (!calRes)
66
74
return {};
67
75
68
- return AttrElementT ::get (resultType, *calRes);
76
+ return ResultAttrElementT ::get (resultType, *calRes);
69
77
}
70
78
71
79
if (isa<SplatElementsAttr>(operands[0 ]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
99
107
return {};
100
108
auto lhsIt = *maybeLhsIt;
101
109
auto rhsIt = *maybeRhsIt;
102
- SmallVector<ElementValueT , 4 > elementResults;
110
+ SmallVector<ResultElementValueT , 4 > elementResults;
103
111
elementResults.reserve (lhs.getNumElements ());
104
112
for (size_t i = 0 , e = lhs.getNumElements (); i < e; ++i, ++lhsIt, ++rhsIt) {
105
113
auto elementResult = calculate (*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
119
127
// / attribute.
120
128
// / Optional PoisonAttr template argument allows to specify 'poison' attribute
121
129
// / which will be directly propagated to result.
122
- template <class AttrElementT ,
130
+ template <class AttrElementT , //
123
131
class ElementValueT = typename AttrElementT::ValueType,
124
132
class PoisonAttr = ub::PoisonAttr,
133
+ class ResultAttrElementT = AttrElementT,
134
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
125
135
class CalculationT = function_ref<
126
- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
136
+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
127
137
Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
128
138
CalculationT &&calculate) {
129
139
assert (operands.size () == 2 && " binary op takes two operands" );
@@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
139
149
return operands[1 ];
140
150
}
141
151
142
- auto getResultType = [](Attribute attr) -> Type {
152
+ auto getAttrType = [](Attribute attr) -> Type {
143
153
if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144
154
return typed.getType ();
145
155
return {};
146
156
};
147
157
148
- Type lhsType = getResultType (operands[0 ]);
149
- Type rhsType = getResultType (operands[1 ]);
158
+ Type lhsType = getAttrType (operands[0 ]);
159
+ Type rhsType = getAttrType (operands[1 ]);
150
160
if (!lhsType || !rhsType)
151
161
return {};
152
162
if (lhsType != rhsType)
153
163
return {};
154
164
155
165
return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166
+ ResultAttrElementT, ResultElementValueT,
156
167
CalculationT>(
157
168
operands, lhsType, std::forward<CalculationT>(calculate));
158
169
}
159
170
160
171
template <class AttrElementT ,
161
172
class ElementValueT = typename AttrElementT::ValueType,
162
- class PoisonAttr = void ,
173
+ class PoisonAttr = void , //
174
+ class ResultAttrElementT = AttrElementT,
175
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
163
176
class CalculationT =
164
- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
177
+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
165
178
Attribute constFoldBinaryOp (ArrayRef<Attribute> operands, Type resultType,
166
179
CalculationT &&calculate) {
167
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
180
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181
+ ResultAttrElementT>(
168
182
operands, resultType,
169
- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170
- return calculate (a, b);
171
- });
183
+ [&](ElementValueT a, ElementValueT b)
184
+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
172
185
}
173
186
174
- template <class AttrElementT ,
187
+ template <class AttrElementT , //
175
188
class ElementValueT = typename AttrElementT::ValueType,
176
189
class PoisonAttr = ub::PoisonAttr,
190
+ class ResultAttrElementT = AttrElementT,
191
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
177
192
class CalculationT =
178
- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
193
+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
179
194
Attribute constFoldBinaryOp (ArrayRef<Attribute> operands,
180
195
CalculationT &&calculate) {
181
- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
196
+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197
+ ResultAttrElementT>(
182
198
operands,
183
- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184
- return calculate (a, b);
185
- });
199
+ [&](ElementValueT a, ElementValueT b)
200
+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
186
201
}
187
202
188
203
// / Performs constant folding `calculate` with element-wise behavior on the one
189
204
// / attributes in `operands` and returns the result if possible.
205
+ // / Uses `resultType` for the type of the returned attribute.
190
206
// / Optional PoisonAttr template argument allows to specify 'poison' attribute
191
207
// / which will be directly propagated to result.
192
- template <class AttrElementT ,
208
+ template <class AttrElementT , //
193
209
class ElementValueT = typename AttrElementT::ValueType,
194
210
class PoisonAttr = ub::PoisonAttr,
211
+ class ResultAttrElementT = AttrElementT,
212
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
195
213
class CalculationT =
196
- function_ref<std::optional<ElementValueT >(ElementValueT)>>
214
+ function_ref<std::optional<ResultElementValueT >(ElementValueT)>>
197
215
Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
216
+ Type resultType,
198
217
CalculationT &&calculate) {
199
- if (!llvm::getSingleElement (operands))
218
+ if (!resultType || ! llvm::getSingleElement (operands))
200
219
return {};
201
220
202
221
static_assert (
@@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
214
233
auto res = calculate (op.getValue ());
215
234
if (!res)
216
235
return {};
217
- return AttrElementT ::get (op. getType () , *res);
236
+ return ResultAttrElementT ::get (resultType , *res);
218
237
}
219
238
if (isa<SplatElementsAttr>(operands[0 ])) {
220
239
// Both operands are splats so we can avoid expanding the values out and
@@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
224
243
auto elementResult = calculate (op.getSplatValue <ElementValueT>());
225
244
if (!elementResult)
226
245
return {};
227
- return DenseElementsAttr::get (op. getType ( ), *elementResult);
246
+ return DenseElementsAttr::get (cast<ShapedType>(resultType ), *elementResult);
228
247
} else if (isa<ElementsAttr>(operands[0 ])) {
229
248
// Operands are ElementsAttr-derived; perform an element-wise fold by
230
249
// expanding the values.
@@ -234,27 +253,89 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
234
253
if (!maybeOpIt)
235
254
return {};
236
255
auto opIt = *maybeOpIt;
237
- SmallVector<ElementValueT > elementResults;
256
+ SmallVector<ResultElementValueT > elementResults;
238
257
elementResults.reserve (op.getNumElements ());
239
258
for (size_t i = 0 , e = op.getNumElements (); i < e; ++i, ++opIt) {
240
259
auto elementResult = calculate (*opIt);
241
260
if (!elementResult)
242
261
return {};
243
262
elementResults.push_back (*elementResult);
244
263
}
245
- return DenseElementsAttr::get (op. getShapedType ( ), elementResults);
264
+ return DenseElementsAttr::get (cast<ShapedType>(resultType ), elementResults);
246
265
}
247
266
return {};
248
267
}
249
268
250
- template <class AttrElementT ,
269
+ // / Performs constant folding `calculate` with element-wise behavior on the one
270
+ // / attributes in `operands` and returns the result if possible.
271
+ // / Uses the operand element type for the element type of the returned
272
+ // / attribute.
273
+ // / Optional PoisonAttr template argument allows to specify 'poison' attribute
274
+ // / which will be directly propagated to result.
275
+ template <class AttrElementT , //
276
+ class ElementValueT = typename AttrElementT::ValueType,
277
+ class PoisonAttr = ub::PoisonAttr,
278
+ class ResultAttrElementT = AttrElementT,
279
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
280
+ class CalculationT =
281
+ function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
282
+ Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
283
+ CalculationT &&calculate) {
284
+ if (!llvm::getSingleElement (operands))
285
+ return {};
286
+
287
+ static_assert (
288
+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289
+ " PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290
+ " void as template argument to opt-out from poison semantics." );
291
+ if constexpr (!std::is_void_v<PoisonAttr>) {
292
+ if (isa<PoisonAttr>(operands[0 ]))
293
+ return operands[0 ];
294
+ }
295
+
296
+ auto getAttrType = [](Attribute attr) -> Type {
297
+ if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298
+ return typed.getType ();
299
+ return {};
300
+ };
301
+
302
+ Type operandType = getAttrType (operands[0 ]);
303
+ if (!operandType)
304
+ return {};
305
+
306
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307
+ ResultAttrElementT, ResultElementValueT,
308
+ CalculationT>(
309
+ operands, operandType, std::forward<CalculationT>(calculate));
310
+ }
311
+
312
+ template <class AttrElementT , //
313
+ class ElementValueT = typename AttrElementT::ValueType,
314
+ class PoisonAttr = ub::PoisonAttr,
315
+ class ResultAttrElementT = AttrElementT,
316
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
317
+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
318
+ Attribute constFoldUnaryOp (ArrayRef<Attribute> operands, Type resultType,
319
+ CalculationT &&calculate) {
320
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321
+ ResultAttrElementT>(
322
+ operands, resultType,
323
+ [&](ElementValueT a) -> std::optional<ResultElementValueT> {
324
+ return calculate (a);
325
+ });
326
+ }
327
+
328
+ template <class AttrElementT , //
251
329
class ElementValueT = typename AttrElementT::ValueType,
252
330
class PoisonAttr = ub::PoisonAttr,
253
- class CalculationT = function_ref<ElementValueT(ElementValueT)>>
331
+ class ResultAttrElementT = AttrElementT,
332
+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
333
+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
254
334
Attribute constFoldUnaryOp (ArrayRef<Attribute> operands,
255
335
CalculationT &&calculate) {
256
- return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
257
- operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
336
+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337
+ ResultAttrElementT>(
338
+ operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
258
339
return calculate (a);
259
340
});
260
341
}
0 commit comments