@@ -59,28 +59,92 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                memref::ExpandShapeOp expandShapeOp,
+                                ValueRange indices,
+                                SmallVectorImpl<Value> &sourceIndices) {
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
+    return failure();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
+  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    // Number of dimensions in the current reassociation group.
+    int64_t groupSize = groups.size();
+
+    // Gather the output dimensions used by this reassociation group for the
+    // suffix product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
     }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for the dimensions and symbols in
+    // the newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
+
+    // Linearize the bound dimensions and symbols to construct the resulting
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
+    SmallVector<OpFoldResult> dynamicIndices(groupSize);
+    for (int64_t i = 0; i < groupSize; i++)
+      dynamicIndices[i] = indices[groups[i]];
+
+    // Supply the suffix product results followed by the load op indices as
+    // operands to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating a maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc,
+        AffineMap::get(/*numDims=*/groupSize,
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push the folded index value corresponding to this reassociation group.
+    sourceIndices.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
   return success();
 }
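
Illustrative sketch of the new expand_shape index resolution (not part of the diff; shapes taken from the memref<12x42xf32> example in the header comment above, MLIR syntax abbreviated):

  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<12x42xf32> into memref<2x6x42xf32>
  %1 = memref.load %0[%i1, %i2, %i3] : memref<2x6x42xf32>
  // Group {0, 1}: output sizes [2, 6], suffix product [6, 1].
  // Composed apply: (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)
  //   with operands (6, 1)[%i1, %i2], folding to 6 * %i1 + %i2.
  %2 = memref.load %arg0[6 * %i1 + %i2, %i3] : memref<12x42xf32>
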
@@ -103,33 +167,49 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
+  int64_t cnt = 0;
+  SmallVector<OpFoldResult> dynamicIndices;
+  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
+    assert(!groups.empty() && "association indices groups cannot be empty");
+    dynamicIndices.push_back(indices[cnt++]);
+    int64_t groupSize = groups.size();
+
+    // Calculate the suffix product of all collapse op source dimension sizes
+    // except the most major one of each group.
+    // We allow the most major source dimension to be dynamic but require all
+    // others to be statically known.
+    SmallVector<int64_t> sizes(groupSize, 1);
+    for (int64_t i = 1; i < groupSize; ++i) {
+      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
+      if (sizes[i] == ShapedType::kDynamic)
+        return failure();
     }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
+    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
+
+    // Derive the index values along all dimensions of the source
+    // corresponding to the index w.r.t. the collapse_shape op output.
+    auto d0 = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
+
+    // Construct the AffineApplyOp for each delinearizingExpr.
+    for (int64_t i = 0; i < groupSize; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc,
+          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
+                         delinearizingExprs[i]),
+          dynamicIndices);
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+    dynamicIndices.clear();
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
+      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
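
Illustrative sketch of the new collapse_shape index resolution (not part of the diff; same assumed shapes as above in reverse, MLIR syntax abbreviated, and the floordiv/mod form of the delinearizing expressions is assumed from the stride-based delinearization):

  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
      : memref<2x6x42xf32> into memref<12x42xf32>
  %1 = memref.load %0[%i1, %i2] : memref<12x42xf32>
  // Group {0, 1}: sizes [1, 6] (most-major size left dynamic-capable),
  //   suffix product [6, 1].
  // Delinearizing exprs (roughly): d0 floordiv 6 and d0 mod 6, applied to %i1.
  %2 = memref.load %arg0[%i1 floordiv 6, %i1 mod 6, %i2] : memref<2x6x42xf32>
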
@@ -433,12 +513,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.load and affine.load guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -600,12 +676,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  // memref.store and affine.store guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {