1515#ifndef MLIR_DIALECT_COMMONFOLDERS_H
1616#define MLIR_DIALECT_COMMONFOLDERS_H
1717
18+ #include " mlir/IR/Attributes.h"
19+ #include " mlir/IR/BuiltinAttributeInterfaces.h"
1820#include " mlir/IR/BuiltinAttributes.h"
19- #include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/BuiltinTypeInterfaces.h"
22+ #include " mlir/IR/Types.h"
2023#include " llvm/ADT/ArrayRef.h"
2124#include " llvm/ADT/STLExtras.h"
25+
26+ #include < cassert>
27+ #include < cstddef>
2228#include < optional>
2329
2430namespace mlir {
@@ -30,11 +36,13 @@ class PoisonAttr;
3036// / Uses `resultType` for the type of the returned attribute.
3137// / Optional PoisonAttr template argument allows to specify 'poison' attribute
3238// / which will be directly propagated to result.
33- template <class AttrElementT ,
39+ template <class AttrElementT , //
3440 class ElementValueT = typename AttrElementT::ValueType,
3541 class PoisonAttr = ub::PoisonAttr,
42+ class ResultAttrElementT = AttrElementT,
43+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
3644 class CalculationT = function_ref<
37- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
45+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
3846Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
3947 Type resultType,
4048 CalculationT &&calculate) {
@@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
6573 if (!calRes)
6674 return {};
6775
68- return AttrElementT ::get (resultType, *calRes);
76+ return ResultAttrElementT ::get (resultType, *calRes);
6977 }
7078
7179 if (isa<SplatElementsAttr>(operands[0 ]) &&
@@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
99107 return {};
100108 auto lhsIt = *maybeLhsIt;
101109 auto rhsIt = *maybeRhsIt;
102- SmallVector<ElementValueT , 4 > elementResults;
110+ SmallVector<ResultElementValueT , 4 > elementResults;
103111 elementResults.reserve (lhs.getNumElements ());
104112 for (size_t i = 0 , e = lhs.getNumElements (); i < e; ++i, ++lhsIt, ++rhsIt) {
105113 auto elementResult = calculate (*lhsIt, *rhsIt);
@@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
119127// / attribute.
120128// / Optional PoisonAttr template argument allows to specify 'poison' attribute
121129// / which will be directly propagated to result.
122- template <class AttrElementT ,
130+ template <class AttrElementT , //
123131 class ElementValueT = typename AttrElementT::ValueType,
124132 class PoisonAttr = ub::PoisonAttr,
133+ class ResultAttrElementT = AttrElementT,
134+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
125135 class CalculationT = function_ref<
126- std::optional<ElementValueT >(ElementValueT, ElementValueT)>>
136+ std::optional<ResultElementValueT >(ElementValueT, ElementValueT)>>
127137Attribute constFoldBinaryOpConditional (ArrayRef<Attribute> operands,
128138 CalculationT &&calculate) {
129139 assert (operands.size () == 2 && " binary op takes two operands" );
@@ -153,36 +163,41 @@ Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
153163 return {};
154164
155165 return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
166+ ResultAttrElementT, ResultElementValueT,
156167 CalculationT>(
157168 operands, lhsType, std::forward<CalculationT>(calculate));
158169}
159170
160171template <class AttrElementT ,
161172 class ElementValueT = typename AttrElementT::ValueType,
162- class PoisonAttr = void ,
173+ class PoisonAttr = void , //
174+ class ResultAttrElementT = AttrElementT,
175+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
163176 class CalculationT =
164- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
177+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
165178Attribute constFoldBinaryOp (ArrayRef<Attribute> operands, Type resultType,
166179 CalculationT &&calculate) {
167- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
180+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
181+ ResultAttrElementT>(
168182 operands, resultType,
169- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
170- return calculate (a, b);
171- });
183+ [&](ElementValueT a, ElementValueT b)
184+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
172185}
173186
174- template <class AttrElementT ,
187+ template <class AttrElementT , //
175188 class ElementValueT = typename AttrElementT::ValueType,
176189 class PoisonAttr = ub::PoisonAttr,
190+ class ResultAttrElementT = AttrElementT,
191+ class ResultElementValueT = typename ResultAttrElementT::ValueType,
177192 class CalculationT =
178- function_ref<ElementValueT (ElementValueT, ElementValueT)>>
193+ function_ref<ResultElementValueT (ElementValueT, ElementValueT)>>
179194Attribute constFoldBinaryOp (ArrayRef<Attribute> operands,
180195 CalculationT &&calculate) {
181- return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
196+ return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
197+ ResultAttrElementT>(
182198 operands,
183- [&](ElementValueT a, ElementValueT b) -> std::optional<ElementValueT> {
184- return calculate (a, b);
185- });
199+ [&](ElementValueT a, ElementValueT b)
200+ -> std::optional<ResultElementValueT> { return calculate (a, b); });
186201}
187202
188203// / Performs constant folding `calculate` with element-wise behavior on the one
0 commit comments