@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///   memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
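To make the new expand_shape path concrete, here is a rough before/after sketch in IR, reusing the 12x42 example from the doc comment above. The SSA names are invented for illustration; this is not output copied from the patch's tests:

```mlir
// Before folding: a load through the expanded view.
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 42]
    : memref<12x42xf32> into memref<2x6x42xf32>
%2 = memref.load %0[%i1, %i2, %i3] : memref<2x6x42xf32>

// After folding: one affine.linearize_index per multi-dimension
// reassociation group (computing 6 * %i1 + %i2 here); size-1 groups pass
// their index through unchanged. `disjoint` is set because memref.load
// indices are known to be in bounds.
%idx = affine.linearize_index disjoint [%i1, %i2] by (2, 6) : index
%2 = memref.load %arg0[%idx, %i3] : memref<12x42xf32>
```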
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  // Note: collapse_shape requires a strided memref, so its sizes can always
+  // be recovered via extract_strided_metadata.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
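The collapse_shape direction is the inverse: each collapsed index is split back into source coordinates with affine.delinearize_index, using the source sizes (recovered through extract_strided_metadata) as the basis. A hedged sketch with invented names and a static shape for readability:

```mlir
// Before folding: a load through the collapsed view.
%0 = memref.collapse_shape %arg0 [[0, 1], [2]]
    : memref<4x8x16xf32> into memref<32x16xf32>
%1 = memref.load %0[%i, %j] : memref<32x16xf32>

// After folding: %i is delinearized against the (4, 8) group; the size-1
// group passes %j through unchanged.
%parts:2 = affine.delinearize_index %i into (4, 8) : index, index
%1 = memref.load %arg0[%parts#0, %parts#1, %j] : memref<4x8x16xf32>
```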
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that their indices start in
+  // bounds, while the vector operations don't. This determines whether the
+  // linearization may be marked `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
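The reason the `disjoint` bit is worth threading through: it records that each operand of the linearization is already in bounds for its basis entry, which is what lets matching linearize/delinearize pairs cancel. A sketch of the kind of round trip this enables, assuming the affine dialect's usual canonicalization patterns:

```mlir
// With `disjoint`, %parts#0 and %parts#1 can be folded back to %i and %j.
%lin = affine.linearize_index disjoint [%i, %j] by (2, 6) : index
%parts:2 = affine.delinearize_index %lin into (2, 6) : index, index
// Without `disjoint`, %j could be >= 6 and carry into the outer term, so
// this fold would be unsound.
```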
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that their indices start in
+  // bounds, while the vector operations don't. This determines whether the
+  // linearization may be marked `disjoint`.
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {