@@ -167,56 +167,127 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }

-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand are the same) if,
+    // at each position, either
+    // 1. the reassociation indices are of the same size, or
+    // 2. the reassociation in the collapse or the expand is of size 1.
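+    // As an illustrative example (with hypothetical shapes), the following
+    // pair is parallel: at position 0 the collapse group has size 2 and the
+    // expand group has size 1, and at position 1 the sizes are 1 and 2:
+    //   %c = tensor.collapse_shape %s [[0, 1], [2]]
+    //       : tensor<2x4x8xf32> into tensor<8x8xf32>
+    //   %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [8, 2, 4]
+    //       : tensor<8x8xf32> into tensor<8x2x4xf32>
+    // A pair where some position has two groups that differ in size and are
+    // both larger than 1 (e.g. [[0, 1], [2, 3]] vs. [[0, 1, 2], [3]]) is
+    // rejected.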
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions, e.g. 4x8x2 collapsed into 64 should
+        // be expanded into 4x8x2 again. In the presence of dynamic dimensions
+        // one can only verify "equality" when there is at most one dynamic
+        // dimension present, and all other static dimensions are equal.
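+        // For instance (hypothetical shapes): collapsing ?x8x2 and expanding
+        // back to ?x8x2 is accepted (at most one dynamic dimension on each
+        // side, all static sizes equal), while expanding to 4x?x? is rejected
+        // since two dynamic sizes cannot be verified to match the source.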
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not the same, one or the other needs to be
+      // of size 1.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }

     // Compute new reassociation indices and expanded/collapsed shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd;
+         idx++) {
+      auto &collapseReassociation = collapseReInds[idx];
+      auto &expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are the same in the collapse producer
+      // and the expand consumer. In the swapped expand, each of the final
+      // dimensions is kept as is in both the expand and the collapse. So,
+      // for every element in the `ReassociationIndices` vector add a new
+      // `ReassociationIndices` vector for the swapped expand and collapse
+      // (of size 1).
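+      // For example (hypothetical shapes): if both reassociations are [0, 1]
+      // over sizes 4x8 (collapsed to 32 and expanded back to 4x8), the
+      // swapped ops keep 4 and 8 apart, i.e. both get the singleton groups
+      // [0] and [1] for this segment.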
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
+      // in the expand is of size == 1). In this case, the original dimensions
+      // are preserved on expansion and collapsed subsequently.
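+      // For example (hypothetical shapes): a collapse group [0, 1] over 2x4
+      // paired with an expand group of size 1 over 8. The new expand keeps 2
+      // and 4 as singleton groups and the new collapse folds them back to 8.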
       if (collapseReassociation.size() != 1) {
         ReassociationIndices newCollapseReassociation;
         for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
         }
+        resultSizeIndex++;
         newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
         continue;
       }
+
+      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+      // in the collapse is of size == 1). In this case, the expansion happens
+      // first and the expanded dimensions are preserved on collapse.
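+      // For example (hypothetical shapes): a collapse group of size 1 over 8
+      // paired with an expand group [0, 1] over 2x4. The new expand splits 8
+      // into 2x4 and the new collapse keeps 2 and 4 as singleton groups.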
       ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
       for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
       }
       newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
     }

     // Swap reshape order.
     SmallVector<Value> dynamicSizes;
     SmallVector<int64_t> staticSizes;
     dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
     auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op.
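+    // For instance, newExpandReInds of [[0], [1], [2]] over a 3-D result
+    // would reshape nothing, so the collapse source is used directly.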
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the collapse_shape
+    // is a no-op.
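+    // Likewise, if every new collapse group is a singleton, the (possibly
+    // elided) expand result is used as the replacement directly.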
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
     return success();
   }
 };