@@ -149,14 +149,14 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
149149      return  operands[1 ];
150150  }
151151
152-   auto  getResultType  = [](Attribute attr) -> Type {
152+   auto  getAttrType  = [](Attribute attr) -> Type {
153153    if  (auto  typed = dyn_cast_or_null<TypedAttr>(attr))
154154      return  typed.getType ();
155155    return  {};
156156  };
157157
158-   Type lhsType = getResultType (operands[0 ]);
159-   Type rhsType = getResultType (operands[1 ]);
158+   Type lhsType = getAttrType (operands[0 ]);
159+   Type rhsType = getAttrType (operands[1 ]);
160160  if  (!lhsType || !rhsType)
161161    return  {};
162162  if  (lhsType != rhsType)
@@ -202,16 +202,20 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
202202
203203// / Performs constant folding `calculate` with element-wise behavior on the one
204204// / attributes in `operands` and returns the result if possible.
205+ // / Uses `resultType` for the type of the returned attribute.
205206// / Optional PoisonAttr template argument allows to specify 'poison' attribute
206207// / which will be directly propagated to result.
207- template  <class  AttrElementT ,
208+ template  <class  AttrElementT ,  // 
208209          class  ElementValueT  = typename  AttrElementT::ValueType,
209210          class  PoisonAttr  = ub::PoisonAttr,
211+           class  ResultAttrElementT  = AttrElementT,
212+           class  ResultElementValueT  = typename  ResultAttrElementT::ValueType,
210213          class  CalculationT  =
211-               function_ref<std::optional<ElementValueT >(ElementValueT)>>
214+               function_ref<std::optional<ResultElementValueT >(ElementValueT)>>
212215Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
216+                                       Type resultType,
213217                                      CalculationT &&calculate) {
214-   if  (!llvm::getSingleElement (operands))
218+   if  (!resultType || ! llvm::getSingleElement (operands))
215219    return  {};
216220
217221  static_assert (
@@ -229,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
229233    auto  res = calculate (op.getValue ());
230234    if  (!res)
231235      return  {};
232-     return  AttrElementT ::get (op. getType () , *res);
236+     return  ResultAttrElementT ::get (resultType , *res);
233237  }
234238  if  (isa<SplatElementsAttr>(operands[0 ])) {
235239    //  Both operands are splats so we can avoid expanding the values out and
@@ -239,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
239243    auto  elementResult = calculate (op.getSplatValue <ElementValueT>());
240244    if  (!elementResult)
241245      return  {};
242-     return  DenseElementsAttr::get (op. getType ( ), *elementResult);
246+     return  DenseElementsAttr::get (cast<ShapedType>(resultType ), *elementResult);
243247  } else  if  (isa<ElementsAttr>(operands[0 ])) {
244248    //  Operands are ElementsAttr-derived; perform an element-wise fold by
245249    //  expanding the values.
@@ -249,27 +253,89 @@ Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
249253    if  (!maybeOpIt)
250254      return  {};
251255    auto  opIt = *maybeOpIt;
252-     SmallVector<ElementValueT > elementResults;
256+     SmallVector<ResultElementValueT > elementResults;
253257    elementResults.reserve (op.getNumElements ());
254258    for  (size_t  i = 0 , e = op.getNumElements (); i < e; ++i, ++opIt) {
255259      auto  elementResult = calculate (*opIt);
256260      if  (!elementResult)
257261        return  {};
258262      elementResults.push_back (*elementResult);
259263    }
260-     return  DenseElementsAttr::get (op. getShapedType ( ), elementResults);
264+     return  DenseElementsAttr::get (cast<ShapedType>(resultType ), elementResults);
261265  }
262266  return  {};
263267}
264268
265- template  <class  AttrElementT ,
269+ // / Performs constant folding `calculate` with element-wise behavior on the one
270+ // / attributes in `operands` and returns the result if possible.
271+ // / Uses the operand element type for the element type of the returned
272+ // / attribute.
273+ // / Optional PoisonAttr template argument allows to specify 'poison' attribute
274+ // / which will be directly propagated to result.
275+ template  <class  AttrElementT , // 
276+           class  ElementValueT  = typename  AttrElementT::ValueType,
277+           class  PoisonAttr  = ub::PoisonAttr,
278+           class  ResultAttrElementT  = AttrElementT,
279+           class  ResultElementValueT  = typename  ResultAttrElementT::ValueType,
280+           class  CalculationT  =
281+               function_ref<std::optional<ResultElementValueT>(ElementValueT)>>
282+ Attribute constFoldUnaryOpConditional (ArrayRef<Attribute> operands,
283+                                       CalculationT &&calculate) {
284+   if  (!llvm::getSingleElement (operands))
285+     return  {};
286+ 
287+   static_assert (
288+       std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
289+       " PoisonAttr is undefined, either add a dependency on UB dialect or pass " 
290+       " void as template argument to opt-out from poison semantics."  );
291+   if  constexpr  (!std::is_void_v<PoisonAttr>) {
292+     if  (isa<PoisonAttr>(operands[0 ]))
293+       return  operands[0 ];
294+   }
295+ 
296+   auto  getAttrType = [](Attribute attr) -> Type {
297+     if  (auto  typed = dyn_cast_or_null<TypedAttr>(attr))
298+       return  typed.getType ();
299+     return  {};
300+   };
301+ 
302+   Type operandType = getAttrType (operands[0 ]);
303+   if  (!operandType)
304+     return  {};
305+ 
306+   return  constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
307+                                      ResultAttrElementT, ResultElementValueT,
308+                                      CalculationT>(
309+       operands, operandType, std::forward<CalculationT>(calculate));
310+ }
311+ 
312+ template  <class  AttrElementT , // 
313+           class  ElementValueT  = typename  AttrElementT::ValueType,
314+           class  PoisonAttr  = ub::PoisonAttr,
315+           class  ResultAttrElementT  = AttrElementT,
316+           class  ResultElementValueT  = typename  ResultAttrElementT::ValueType,
317+           class  CalculationT  = function_ref<ResultElementValueT(ElementValueT)>>
318+ Attribute constFoldUnaryOp (ArrayRef<Attribute> operands, Type resultType,
319+                            CalculationT &&calculate) {
320+   return  constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
321+                                      ResultAttrElementT>(
322+       operands, resultType,
323+       [&](ElementValueT a) -> std::optional<ResultElementValueT> {
324+         return  calculate (a);
325+       });
326+ }
327+ 
328+ template  <class  AttrElementT , // 
266329          class  ElementValueT  = typename  AttrElementT::ValueType,
267330          class  PoisonAttr  = ub::PoisonAttr,
268-           class  CalculationT  = function_ref<ElementValueT(ElementValueT)>>
331+           class  ResultAttrElementT  = AttrElementT,
332+           class  ResultElementValueT  = typename  ResultAttrElementT::ValueType,
333+           class  CalculationT  = function_ref<ResultElementValueT(ElementValueT)>>
269334Attribute constFoldUnaryOp (ArrayRef<Attribute> operands,
270335                           CalculationT &&calculate) {
271-   return  constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
272-       operands, [&](ElementValueT a) -> std::optional<ElementValueT> {
336+   return  constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
337+                                      ResultAttrElementT>(
338+       operands, [&](ElementValueT a) -> std::optional<ResultElementValueT> {
273339        return  calculate (a);
274340      });
275341}
0 commit comments