@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
8888//  Utility functions for propagating static information
8989// ===----------------------------------------------------------------------===//
9090
91- // / Helper function that infers the constant values from a list of \p values,
92- // / a \p memRefTy, and another helper function \p getAttributes.
93- // / The inferred constant values replace the related `OpFoldResult` in
94- // / \p values.
91+ // / Helper function that sets values[i] to constValues[i] if the latter is a
92+ // / static value, as indicated by ShapedType::kDynamic.
9593// /
96- // / \note This function shouldn't be used directly, instead, use the
97- // / `getConstifiedMixedXXX` methods from the related operations.
98- // /
99- // / \p getAttributes retuns a list of potentially constant values, as determined
100- // / by \p isDynamic, from the given \p memRefTy. The returned list must have as
101- // / many elements as \p values or be empty.
102- // /
103- // / E.g., consider the following example:
104- // / ```
105- // / memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
106- // /     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
107- // / ```
108- // / `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
109- // / Now using this helper function with:
110- // / - `values == [2, %dyn_stride]`,
111- // / - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
112- // / - `getAttributes == getConstantStrides` (i.e., a wrapper around
113- // / `getStridesAndOffset`), and
114- // / - `isDynamic == ShapedType::isDynamic`
115- // / Will yield: `values == [2, 1]`
116- static  void  constifyIndexValues (
117-     SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
118-     MLIRContext *ctxt,
119-     llvm::function_ref<SmallVector<int64_t >(MemRefType)> getAttributes,
120-     llvm::function_ref<bool(int64_t )> isDynamic) {
121-   SmallVector<int64_t > constValues = getAttributes (memRefTy);
122-   Builder builder (ctxt);
123-   for  (const  auto  &it : llvm::enumerate (constValues)) {
124-     int64_t  constValue = it.value ();
125-     if  (!isDynamic (constValue))
126-       values[it.index ()] = builder.getIndexAttr (constValue);
127-   }
128-   for  (OpFoldResult &ofr : values) {
129-     if  (auto  attr = dyn_cast<Attribute>(ofr)) {
130-       //  FIXME: We shouldn't need to do that, but right now, the static indices
131-       //  are created with the wrong type: `i64` instead of `index`.
132-       //  As a result, if we were to keep the attribute as is, we may fail to see
133-       //  that two attributes are equal because one would have the i64 type and
134-       //  the other the index type.
135-       //  The alternative would be to create constant indices with getI64Attr in
136-       //  this and the previous loop, but it doesn't logically make sense (we are
137-       //  dealing with indices here) and would only strenghten the inconsistency
138-       //  around how static indices are created (some places use getI64Attr,
139-       //  others use getIndexAttr).
140-       //  The workaround here is to stick to the IndexAttr type for all the
141-       //  values, hence we recreate the attribute even when it is already static
142-       //  to make sure the type is consistent.
143-       ofr = builder.getIndexAttr (llvm::cast<IntegerAttr>(attr).getInt ());
94+ // / If constValues[i] is dynamic, tries to extract a constant value from
95+ // / value[i] to allow for additional folding opportunities. Also convertes all
96+ // / existing attributes to index attributes. (They may be i64 attributes.)
97+ static  void  constifyIndexValues (SmallVectorImpl<OpFoldResult> &values,
98+                                 ArrayRef<int64_t > constValues) {
99+   assert (constValues.size () == values.size () &&
100+          " incorrect number of const values" 
101+   for  (auto  [i, cstVal] : llvm::enumerate (constValues)) {
102+     Builder builder (values[i].getContext ());
103+     if  (!ShapedType::isDynamic (cstVal)) {
104+       //  Constant value is known, use it directly.
105+       values[i] = builder.getIndexAttr (cstVal);
144106      continue ;
145107    }
146-     std::optional<int64_t > maybeConstant = 
147-          getConstantIntValue (cast<Value>(ofr)); 
148-     if  (maybeConstant) 
149-       ofr = builder. getIndexAttr (*maybeConstant); 
108+     if  ( std::optional<int64_t > cst =  getConstantIntValue (values[i])) { 
109+       //  Try to extract a constant or convert an existing to index. 
110+       values[i] = builder. getIndexAttr (*cst); 
111+     } 
150112  }
151113}
152114
153- // / Wrapper around `getShape` that conforms to the function signature
154- // / expected for `getAttributes` in `constifyIndexValues`.
155- static  SmallVector<int64_t > getConstantSizes (MemRefType memRefTy) {
156-   ArrayRef<int64_t > sizes = memRefTy.getShape ();
157-   return  SmallVector<int64_t >(sizes);
158- }
159- 
160- // / Wrapper around `getStridesAndOffset` that returns only the offset and
161- // / conforms to the function signature expected for `getAttributes` in
162- // / `constifyIndexValues`.
163- static  SmallVector<int64_t > getConstantOffset (MemRefType memrefType) {
164-   SmallVector<int64_t > strides;
165-   int64_t  offset;
166-   LogicalResult hasStaticInformation =
167-       memrefType.getStridesAndOffset (strides, offset);
168-   if  (failed (hasStaticInformation))
169-     return  SmallVector<int64_t >();
170-   return  SmallVector<int64_t >(1 , offset);
171- }
172- 
173- // / Wrapper around `getStridesAndOffset` that returns only the strides and
174- // / conforms to the function signature expected for `getAttributes` in
175- // / `constifyIndexValues`.
176- static  SmallVector<int64_t > getConstantStrides (MemRefType memrefType) {
177-   SmallVector<int64_t > strides;
178-   int64_t  offset;
179-   LogicalResult hasStaticInformation =
180-       memrefType.getStridesAndOffset (strides, offset);
181-   if  (failed (hasStaticInformation))
182-     return  SmallVector<int64_t >();
183-   return  strides;
184- }
185- 
186115// ===----------------------------------------------------------------------===//
187116//  AllocOp / AllocaOp
188117// ===----------------------------------------------------------------------===//
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
14451374
14461375SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes () {
14471376  SmallVector<OpFoldResult> values = getAsOpFoldResult (getSizes ());
1448-   constifyIndexValues (values, getSource ().getType (), getContext (),
1449-                       getConstantSizes, ShapedType::isDynamic);
1377+   constifyIndexValues (values, getSource ().getType ().getShape ());
14501378  return  values;
14511379}
14521380
14531381SmallVector<OpFoldResult>
14541382ExtractStridedMetadataOp::getConstifiedMixedStrides () {
14551383  SmallVector<OpFoldResult> values = getAsOpFoldResult (getStrides ());
1456-   constifyIndexValues (values, getSource ().getType (), getContext (),
1457-                       getConstantStrides, ShapedType::isDynamic);
1384+   SmallVector<int64_t > staticValues;
1385+   int64_t  unused;
1386+   LogicalResult status =
1387+       getSource ().getType ().getStridesAndOffset (staticValues, unused);
1388+   (void )status;
1389+   assert (succeeded (status) && " could not get strides from type" 
1390+   constifyIndexValues (values, staticValues);
14581391  return  values;
14591392}
14601393
14611394OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset () {
14621395  OpFoldResult offsetOfr = getAsOpFoldResult (getOffset ());
14631396  SmallVector<OpFoldResult> values (1 , offsetOfr);
1464-   constifyIndexValues (values, getSource ().getType (), getContext (),
1465-                       getConstantOffset, ShapedType::isDynamic);
1397+   SmallVector<int64_t > staticValues, unused;
1398+   int64_t  offset;
1399+   LogicalResult status =
1400+       getSource ().getType ().getStridesAndOffset (unused, offset);
1401+   (void )status;
1402+   assert (succeeded (status) && " could not get offset from type" 
1403+   staticValues.push_back (offset);
1404+   constifyIndexValues (values, staticValues);
14661405  return  values[0 ];
14671406}
14681407
@@ -1975,24 +1914,32 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
19751914
19761915SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes () {
19771916  SmallVector<OpFoldResult> values = getMixedSizes ();
1978-   constifyIndexValues (values, getType (), getContext (), getConstantSizes,
1979-                       ShapedType::isDynamic);
1917+   constifyIndexValues (values, getType ().getShape ());
19801918  return  values;
19811919}
19821920
19831921SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides () {
19841922  SmallVector<OpFoldResult> values = getMixedStrides ();
1985-   constifyIndexValues (values, getType (), getContext (), getConstantStrides,
1986-                       ShapedType::isDynamic);
1923+   SmallVector<int64_t > staticValues;
1924+   int64_t  unused;
1925+   LogicalResult status = getType ().getStridesAndOffset (staticValues, unused);
1926+   (void )status;
1927+   assert (succeeded (status) && " could not get strides from type" 
1928+   constifyIndexValues (values, staticValues);
19871929  return  values;
19881930}
19891931
19901932OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset () {
19911933  SmallVector<OpFoldResult> values = getMixedOffsets ();
19921934  assert (values.size () == 1  &&
19931935         " reinterpret_cast must have one and only one offset" 
1994-   constifyIndexValues (values, getType (), getContext (), getConstantOffset,
1995-                       ShapedType::isDynamic);
1936+   SmallVector<int64_t > staticValues, unused;
1937+   int64_t  offset;
1938+   LogicalResult status = getType ().getStridesAndOffset (unused, offset);
1939+   (void )status;
1940+   assert (succeeded (status) && " could not get offset from type" 
1941+   staticValues.push_back (offset);
1942+   constifyIndexValues (values, staticValues);
19961943  return  values[0 ];
19971944}
19981945
0 commit comments