@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
8888// Utility functions for propagating static information
8989// ===----------------------------------------------------------------------===//
9090
91- // / Helper function that infers the constant values from a list of \p values,
92- // / a \p memRefTy, and another helper function \p getAttributes.
93- // / The inferred constant values replace the related `OpFoldResult` in
94- // / \p values.
91+ // / Helper function that sets values[i] to constValues[i] if the latter is a
92+ // / static value, as indicated by ShapedType::kDynamic.
9593// /
96- // / \note This function shouldn't be used directly, instead, use the
97- // / `getConstifiedMixedXXX` methods from the related operations.
98- // /
99- // / \p getAttributes retuns a list of potentially constant values, as determined
100- // / by \p isDynamic, from the given \p memRefTy. The returned list must have as
101- // / many elements as \p values or be empty.
102- // /
103- // / E.g., consider the following example:
104- // / ```
105- // / memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
106- // / memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
107- // / ```
108- // / `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
109- // / Now using this helper function with:
110- // / - `values == [2, %dyn_stride]`,
111- // / - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
112- // / - `getAttributes == getConstantStrides` (i.e., a wrapper around
113- // / `getStridesAndOffset`), and
114- // / - `isDynamic == ShapedType::isDynamic`
115- // / Will yield: `values == [2, 1]`
116- static void constifyIndexValues (
117- SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
118- MLIRContext *ctxt,
119- llvm::function_ref<SmallVector<int64_t >(MemRefType)> getAttributes,
120- llvm::function_ref<bool(int64_t )> isDynamic) {
121- SmallVector<int64_t > constValues = getAttributes (memRefTy);
122- Builder builder (ctxt);
123- for (const auto &it : llvm::enumerate (constValues)) {
124- int64_t constValue = it.value ();
125- if (!isDynamic (constValue))
126- values[it.index ()] = builder.getIndexAttr (constValue);
127- }
128- for (OpFoldResult &ofr : values) {
129- if (auto attr = dyn_cast<Attribute>(ofr)) {
130- // FIXME: We shouldn't need to do that, but right now, the static indices
131- // are created with the wrong type: `i64` instead of `index`.
132- // As a result, if we were to keep the attribute as is, we may fail to see
133- // that two attributes are equal because one would have the i64 type and
134- // the other the index type.
135- // The alternative would be to create constant indices with getI64Attr in
136- // this and the previous loop, but it doesn't logically make sense (we are
137- // dealing with indices here) and would only strenghten the inconsistency
138- // around how static indices are created (some places use getI64Attr,
139- // others use getIndexAttr).
140- // The workaround here is to stick to the IndexAttr type for all the
141- // values, hence we recreate the attribute even when it is already static
142- // to make sure the type is consistent.
143- ofr = builder.getIndexAttr (llvm::cast<IntegerAttr>(attr).getInt ());
94+ // / If constValues[i] is dynamic, tries to extract a constant value from
95+ // / value[i] to allow for additional folding opportunities. Also convertes all
96+ // / existing attributes to index attributes. (They may be i64 attributes.)
97+ static void constifyIndexValues (SmallVectorImpl<OpFoldResult> &values,
98+ ArrayRef<int64_t > constValues) {
99+ assert (constValues.size () == values.size () &&
100+ " incorrect number of const values" );
101+ for (auto [i, cstVal] : llvm::enumerate (constValues)) {
102+ Builder builder (values[i].getContext ());
103+ if (!ShapedType::isDynamic (cstVal)) {
104+ // Constant value is known, use it directly.
105+ values[i] = builder.getIndexAttr (cstVal);
144106 continue ;
145107 }
146- std::optional<int64_t > maybeConstant =
147- getConstantIntValue (cast<Value>(ofr));
148- if (maybeConstant)
149- ofr = builder. getIndexAttr (*maybeConstant);
108+ if ( std::optional<int64_t > cst = getConstantIntValue (values[i])) {
109+ // Try to extract a constant or convert an existing to index.
110+ values[i] = builder. getIndexAttr (*cst);
111+ }
150112 }
151113}
152114
153- // / Wrapper around `getShape` that conforms to the function signature
154- // / expected for `getAttributes` in `constifyIndexValues`.
155- static SmallVector<int64_t > getConstantSizes (MemRefType memRefTy) {
156- ArrayRef<int64_t > sizes = memRefTy.getShape ();
157- return SmallVector<int64_t >(sizes);
158- }
159-
160- // / Wrapper around `getStridesAndOffset` that returns only the offset and
161- // / conforms to the function signature expected for `getAttributes` in
162- // / `constifyIndexValues`.
163- static SmallVector<int64_t > getConstantOffset (MemRefType memrefType) {
164- SmallVector<int64_t > strides;
165- int64_t offset;
166- LogicalResult hasStaticInformation =
167- memrefType.getStridesAndOffset (strides, offset);
168- if (failed (hasStaticInformation))
169- return SmallVector<int64_t >();
170- return SmallVector<int64_t >(1 , offset);
171- }
172-
173- // / Wrapper around `getStridesAndOffset` that returns only the strides and
174- // / conforms to the function signature expected for `getAttributes` in
175- // / `constifyIndexValues`.
176- static SmallVector<int64_t > getConstantStrides (MemRefType memrefType) {
177- SmallVector<int64_t > strides;
178- int64_t offset;
179- LogicalResult hasStaticInformation =
180- memrefType.getStridesAndOffset (strides, offset);
181- if (failed (hasStaticInformation))
182- return SmallVector<int64_t >();
183- return strides;
184- }
185-
186115// ===----------------------------------------------------------------------===//
187116// AllocOp / AllocaOp
188117// ===----------------------------------------------------------------------===//
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
14451374
14461375SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes () {
14471376 SmallVector<OpFoldResult> values = getAsOpFoldResult (getSizes ());
1448- constifyIndexValues (values, getSource ().getType (), getContext (),
1449- getConstantSizes, ShapedType::isDynamic);
1377+ constifyIndexValues (values, getSource ().getType ().getShape ());
14501378 return values;
14511379}
14521380
14531381SmallVector<OpFoldResult>
14541382ExtractStridedMetadataOp::getConstifiedMixedStrides () {
14551383 SmallVector<OpFoldResult> values = getAsOpFoldResult (getStrides ());
1456- constifyIndexValues (values, getSource ().getType (), getContext (),
1457- getConstantStrides, ShapedType::isDynamic);
1384+ SmallVector<int64_t > staticValues;
1385+ int64_t unused;
1386+ LogicalResult status =
1387+ getSource ().getType ().getStridesAndOffset (staticValues, unused);
1388+ (void )status;
1389+ assert (succeeded (status) && " could not get strides from type" );
1390+ constifyIndexValues (values, staticValues);
14581391 return values;
14591392}
14601393
14611394OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset () {
14621395 OpFoldResult offsetOfr = getAsOpFoldResult (getOffset ());
14631396 SmallVector<OpFoldResult> values (1 , offsetOfr);
1464- constifyIndexValues (values, getSource ().getType (), getContext (),
1465- getConstantOffset, ShapedType::isDynamic);
1397+ SmallVector<int64_t > staticValues, unused;
1398+ int64_t offset;
1399+ LogicalResult status =
1400+ getSource ().getType ().getStridesAndOffset (unused, offset);
1401+ (void )status;
1402+ assert (succeeded (status) && " could not get offset from type" );
1403+ staticValues.push_back (offset);
1404+ constifyIndexValues (values, staticValues);
14661405 return values[0 ];
14671406}
14681407
@@ -1975,24 +1914,32 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
19751914
19761915SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes () {
19771916 SmallVector<OpFoldResult> values = getMixedSizes ();
1978- constifyIndexValues (values, getType (), getContext (), getConstantSizes,
1979- ShapedType::isDynamic);
1917+ constifyIndexValues (values, getType ().getShape ());
19801918 return values;
19811919}
19821920
19831921SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides () {
19841922 SmallVector<OpFoldResult> values = getMixedStrides ();
1985- constifyIndexValues (values, getType (), getContext (), getConstantStrides,
1986- ShapedType::isDynamic);
1923+ SmallVector<int64_t > staticValues;
1924+ int64_t unused;
1925+ LogicalResult status = getType ().getStridesAndOffset (staticValues, unused);
1926+ (void )status;
1927+ assert (succeeded (status) && " could not get strides from type" );
1928+ constifyIndexValues (values, staticValues);
19871929 return values;
19881930}
19891931
19901932OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset () {
19911933 SmallVector<OpFoldResult> values = getMixedOffsets ();
19921934 assert (values.size () == 1 &&
19931935 " reinterpret_cast must have one and only one offset" );
1994- constifyIndexValues (values, getType (), getContext (), getConstantOffset,
1995- ShapedType::isDynamic);
1936+ SmallVector<int64_t > staticValues, unused;
1937+ int64_t offset;
1938+ LogicalResult status = getType ().getStridesAndOffset (unused, offset);
1939+ (void )status;
1940+ assert (succeeded (status) && " could not get offset from type" );
1941+ staticValues.push_back (offset);
1942+ constifyIndexValues (values, staticValues);
19961943 return values[0 ];
19971944}
19981945
0 commit comments