2424#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2525#include " llvm/ADT/STLExtras.h"
2626#include " llvm/ADT/SmallBitVector.h"
27+ #include " mlir/IR/OpDefinition.h"
2728#include < optional>
2829
2930namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
3536
3637using namespace mlir ;
3738using namespace mlir ::affine;
39+ using namespace mlir ::memref;
3840
3941namespace {
4042
@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
250252 }
251253};
252254
253- // / Compute the expanded sizes of the given \p expandShape for the
254- // / \p groupId-th reassociation group.
255- // / \p origSizes hold the sizes of the source shape as values.
256- // / This is used to compute the new sizes in cases of dynamic shapes.
257- // /
258- // / sizes#i =
259- // / baseSizes#groupId / product(expandShapeSizes#j,
260- // / for j in group excluding reassIdx#i)
261- // / Where reassIdx#i is the reassociation index at index i in \p groupId.
262- // /
263- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
264- // /
265- // / TODO: Move this utility function directly within ExpandShapeOp. For now,
266- // / this is not possible because this function uses the Affine dialect and the
267- // / MemRef dialect cannot depend on the Affine dialect.
268- static SmallVector<OpFoldResult>
269- getExpandedSizes (memref::ExpandShapeOp expandShape, OpBuilder &builder,
255+ } // namespace
256+
257+ namespace mlir {
258+ namespace memref {
259+ SmallVector<OpFoldResult>
260+ getExpandedSizes (ExpandShapeOp expandShape, OpBuilder &builder,
270261 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
271262 SmallVector<int64_t , 2 > reassocGroup =
272263 expandShape.getReassociationIndices ()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
305296 return expandedSizes;
306297}
307298
308- // / Compute the expanded strides of the given \p expandShape for the
309- // / \p groupId-th reassociation group.
310- // / \p origStrides and \p origSizes hold respectively the strides and sizes
311- // / of the source shape as values.
312- // / This is used to compute the strides in cases of dynamic shapes and/or
313- // / dynamic stride for this reassociation group.
314- // /
315- // / strides#i =
316- // / origStrides#reassDim * product(expandShapeSizes#j, for j in
317- // / reassIdx#i+1..reassIdx#i+group.size-1)
318- // /
319- // / Where reassIdx#i is the reassociation index for at index i in \p groupId
320- // / and expandShapeSizes#j is either:
321- // / - The constant size at dimension j, derived directly from the result type of
322- // / the expand_shape op, or
323- // / - An affine expression: baseSizes#reassDim / product of all constant sizes
324- // / in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325- // / element.)
326- // /
327- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
328- // /
329- // / TODO: Move this utility function directly within ExpandShapeOp. For now,
330- // / this is not possible because this function uses the Affine dialect and the
331- // / MemRef dialect cannot depend on the Affine dialect.
332- SmallVector<OpFoldResult> getExpandedStrides (memref::ExpandShapeOp expandShape,
299+ SmallVector<OpFoldResult> getExpandedStrides (ExpandShapeOp expandShape,
333300 OpBuilder &builder,
334301 ArrayRef<OpFoldResult> origSizes,
335302 ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
405372 return expandedStrides;
406373}
407374
408- // / Produce an OpFoldResult object with \p builder at \p loc representing
409- // / `prod(valueOrConstant#i, for i in {indices})`,
410- // / where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
411- // / values[i] otherwise.
412- // /
413- // / \pre for all index in indices: index < values.size()
414- // / \pre for all index in indices: index < maybeConstants.size()
415- static OpFoldResult
375+ OpFoldResult
416376getProductOfValues (ArrayRef<int64_t > indices, OpBuilder &builder, Location loc,
417377 ArrayRef<int64_t > maybeConstants,
418378 ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
450410// / TODO: Move this utility function directly within CollapseShapeOp. For now,
451411// / this is not possible because this function uses the Affine dialect and the
452412// / MemRef dialect cannot depend on the Affine dialect.
453- static SmallVector<OpFoldResult>
454- getCollapsedSize (memref:: CollapseShapeOp collapseShape, OpBuilder &builder,
413+ SmallVector<OpFoldResult>
414+ getCollapsedSize (CollapseShapeOp collapseShape, OpBuilder &builder,
455415 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
456416 SmallVector<OpFoldResult> collapsedSize;
457417
@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
491451// /
492452// / \post result.size() == 1, in other words, each group collapse to one
493453// / dimension.
494- static SmallVector<OpFoldResult>
495- getCollapsedStride (memref:: CollapseShapeOp collapseShape, OpBuilder &builder,
454+ SmallVector<OpFoldResult>
455+ getCollapsedStride (CollapseShapeOp collapseShape, OpBuilder &builder,
496456 ArrayRef<OpFoldResult> origSizes,
497457 ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
498458 SmallVector<int64_t , 2 > reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
546506
547507 return {lastValidStride};
548508}
509+ } // namespace memref
510+ } // namespace mlir
511+
512+ namespace {
549513
550514// / From `reshape_like(memref, subSizes, subStrides))` compute
551515// /
0 commit comments