1515#ifndef MLIR_DIALECT_COMMONFOLDERS_H
1616#define MLIR_DIALECT_COMMONFOLDERS_H
1717
18+ #include " mlir/IR/Attributes.h"
19+ #include " mlir/IR/BuiltinAttributeInterfaces.h"
1820#include " mlir/IR/BuiltinAttributes.h"
19- #include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/BuiltinTypeInterfaces.h"
22+ #include " mlir/IR/Types.h"
2023#include " llvm/ADT/ArrayRef.h"
2124#include " llvm/ADT/STLExtras.h"
25+
26+ #include < cassert>
27+ #include < cstddef>
2228#include < optional>
2329
2430namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
3036// / Uses `resultType` for the type of the returned attribute.
3137// / Optional PoisonAttr template argument allows to specify 'poison' attribute
3238// / which will be directly propagated to result.
33- template <class AttrElementT ,
39+ template <class AttrElementT , //
3440 class ElementValueT = typename AttrElementT::ValueType,
3541 class PoisonAttr = ub::PoisonAttr,
42+ class ResultAttrElementT = AttrElementT,
43+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
3644 class CalculationT = function_ref<
37- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
45+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
3846Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
3947 Type resultType,
4048 CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
6573 if (!calRes)
6674 return {};
6775
68- return AttrElementT ::get (resultType, *calRes);
76+ return ResultAttrElementT ::get (resultType, *calRes);
6977 }
7078
7179 if (isa<SplatElementsAttr>(operands[0 ]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
99107 return {};
100108 auto lhsIt = *maybeLhsIt;
101109 auto rhsIt = *maybeRhsIt;
102- SmallVector<ElementValueT , 4 > elementResults;
110+ SmallVector<ResultElementValueT , 4 > elementResults;
103111 elementResults.reserve (lhs.getNumElements ());
104112 for (size_t i = 0 , e = lhs.getNumElements (); i < e; ++i, ++lhsIt, ++rhsIt) {
105113 auto elementResult = calculate (*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
119127// / attribute.
120128// / Optional PoisonAttr template argument allows to specify 'poison' attribute
121129// / which will be directly propagated to result.
122- template <class AttrElementT ,
130+ template <class AttrElementT , //
123131 class ElementValueT = typename AttrElementT::ValueType,
124132 class PoisonAttr = ub::PoisonAttr,
133+ class ResultAttrElementT = AttrElementT,
134+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
125135 class CalculationT = function_ref<
126- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
136+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
127137Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
128138 CalculationT &&calculate) {
129139 assert (operands.size () == 2 && " binary op takes two operands" );
@@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
139149 return operands[1 ];
140150 }
141151
142- auto getResultType = [](Attribute attr) -> Type {
152+ auto getAttrType = [](Attribute attr) -> Type {
143153 if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
144154 return typed.getType ();
145155 return {};
146156 };
147157
148- Type lhsType = getResultType (operands[0 ]);
149- Type rhsType = getResultType (operands[1 ]);
158+ Type lhsType = getAttrType (operands[0 ]);
159+ Type rhsType = getAttrType (operands[1 ]);
150160 if (!lhsType || !rhsType)
151161 return {};
152162 if (lhsType != rhsType)
153163 return {};
154164
155165 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166+ ResultAttrElementT, ResultElementValueT,
156167 CalculationT>(
157168 operands, lhsType, std::forward<CalculationT>(calculate));
158169}
159170
160171template <class AttrElementT ,
161172 class ElementValueT = typename AttrElementT::ValueType,
162- class PoisonAttr = void ,
173+ class PoisonAttr = void , //
174+ class ResultAttrElementT = AttrElementT,
175+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
163176 class CalculationT =
164- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
177+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
165178Attribute constFoldBinaryOp (ArrayRef<Attribute> operands, Type resultType,
166179 CalculationT &&calculate) {
167- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
180+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181+ ResultAttrElementT>(
168182 operands, resultType,
169- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170- return calculate (a, b);
171- });
183+ [&](ElementValueT a, ElementValueT b)
184+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
172185}
173186
174- template <class AttrElementT ,
187+ template <class AttrElementT , //
175188 class ElementValueT = typename AttrElementT::ValueType,
176189 class PoisonAttr = ub::PoisonAttr,
190+ class ResultAttrElementT = AttrElementT,
191+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
177192 class CalculationT =
178- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
193+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
179194Attribute constFoldBinaryOp (ArrayRef<Attribute> operands,
180195 CalculationT &&calculate) {
181- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
196+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197+ ResultAttrElementT>(
182198 operands,
183- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184- return calculate (a, b);
185- });
199+ [&](ElementValueT a, ElementValueT b)
200+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
186201}
187202
188203// / Performs constant folding `calculate` with element-wise behavior on the one
189204// / attributes in `operands` and returns the result if possible.
205+ // / Uses `resultType` for the type of the returned attribute.
190206// / Optional PoisonAttr template argument allows to specify 'poison' attribute
191207// / which will be directly propagated to result.
192- template <class AttrElementT ,
208+ template <class AttrElementT , //
193209 class ElementValueT = typename AttrElementT::ValueType,
194210 class PoisonAttr = ub::PoisonAttr,
211+ class ResultAttrElementT = AttrElementT,
212+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
195213 class CalculationT =
196- function_ref<std::optional<ElementValueT >(ElementValueT)>>
214+ function_ref<std::optional<ResultElementValueT >(ElementValueT)>>
197215Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
216+ Type resultType,
198217 CalculationT &&calculate) {
199- if (!llvm::getSingleElement (operands))
218+ if (!resultType || ! llvm::getSingleElement (operands))
200219 return {};
201220
202221 static_assert (
@@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
214233 auto res = calculate (op.getValue ());
215234 if (!res)
216235 return {};
217- return AttrElementT ::get (op. getType () , *res);
236+ return ResultAttrElementT ::get (resultType , *res);
218237 }
219238 if (isa<SplatElementsAttr>(operands[0 ])) {
220239 // Both operands are splats so we can avoid expanding the values out and
@@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
224243 auto elementResult = calculate (op.getSplatValue <ElementValueT>());
225244 if (!elementResult)
226245 return {};
227- return DenseElementsAttr::get (op. getType ( ), *elementResult);
246+ return DenseElementsAttr::get (cast<ShapedType>(resultType ), *elementResult);
228247 } else if (isa<ElementsAttr>(operands[0 ])) {
229248 // Operands are ElementsAttr-derived; perform an element-wise fold by
230249 // expanding the values.
@@ -234,27 +253,89 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
234253 if (!maybeOpIt)
235254 return {};
236255 auto opIt = *maybeOpIt;
237- SmallVector<ElementValueT > elementResults;
256+ SmallVector<ResultElementValueT > elementResults;
238257 elementResults.reserve (op.getNumElements ());
239258 for (size_t i = 0 , e = op.getNumElements (); i < e; ++i, ++opIt) {
240259 auto elementResult = calculate (*opIt);
241260 if (!elementResult)
242261 return {};
243262 elementResults.push_back (*elementResult);
244263 }
245- return DenseElementsAttr::get (op. getShapedType ( ), elementResults);
264+ return DenseElementsAttr::get (cast<ShapedType>(resultType ), elementResults);
246265 }
247266 return {};
248267}
249268
250- template <class AttrElementT ,
269+ // / Performs constant folding `calculate` with element-wise behavior on the one
270+ // / attributes in `operands` and returns the result if possible.
271+ // / Uses the operand element type for the element type of the returned
272+ // / attribute.
273+ // / Optional PoisonAttr template argument allows to specify 'poison' attribute
274+ // / which will be directly propagated to result.
275+ template <class AttrElementT , //
276+ class ElementValueT = typename AttrElementT::ValueType,
277+ class PoisonAttr = ub::PoisonAttr,
278+ class ResultAttrElementT = AttrElementT,
279+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
280+ class CalculationT =
281+ function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
282+ Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
283+ CalculationT &&calculate) {
284+ if (!llvm::getSingleElement (operands))
285+ return {};
286+
287+ static_assert (
288+ std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289+ " PoisonAttr is undefined, either add a dependency on UB dialect or pass "
290+ " void as template argument to opt-out from poison semantics." );
291+ if constexpr (!std::is_void_v<PoisonAttr>) {
292+ if (isa<PoisonAttr>(operands[0 ]))
293+ return operands[0 ];
294+ }
295+
296+ auto getAttrType = [](Attribute attr) -> Type {
297+ if (auto typed = dyn_cast_or_null<TypedAttr>(attr))
298+ return typed.getType ();
299+ return {};
300+ };
301+
302+ Type operandType = getAttrType (operands[0 ]);
303+ if (!operandType)
304+ return {};
305+
306+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307+ ResultAttrElementT, ResultElementValueT,
308+ CalculationT>(
309+ operands, operandType, std::forward<CalculationT>(calculate));
310+ }
311+
312+ template <class AttrElementT , //
313+ class ElementValueT = typename AttrElementT::ValueType,
314+ class PoisonAttr = ub::PoisonAttr,
315+ class ResultAttrElementT = AttrElementT,
316+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
317+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
318+ Attribute constFoldUnaryOp (ArrayRef<Attribute> operands, Type resultType,
319+ CalculationT &&calculate) {
320+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321+ ResultAttrElementT>(
322+ operands, resultType,
323+ [&](ElementValueT a) -> std::optional<ResultElementValueT> {
324+ return calculate (a);
325+ });
326+ }
327+
328+ template <class AttrElementT , //
251329 class ElementValueT = typename AttrElementT::ValueType,
252330 class PoisonAttr = ub::PoisonAttr,
253- class CalculationT = function_ref<ElementValueT(ElementValueT)>>
331+ class ResultAttrElementT = AttrElementT,
332+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
333+ class CalculationT = function_ref<ResultElementValueT(ElementValueT)>>
254334Attribute constFoldUnaryOp (ArrayRef<Attribute> operands,
255335 CalculationT &&calculate) {
256- return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
257- operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
336+ return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337+ ResultAttrElementT>(
338+ operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
258339 return calculate (a);
259340 });
260341}
0 commit comments