@@ -167,56 +167,126 @@ struct BubbleUpExpandThroughParallelCollapse
167167 return failure ();
168168 }
169169
170- // Reshapes are parallel to each other if none of the reassociation indices
171- // have greater than 1 index for both reshapes.
170+ // Reshapes are parallel to each other (by construction the number of
171+ // reassociations specified in the collapse and expand are the same), if at
172+ // any position
173+ // 1. either the reassociation indices are of the same size, or
174+ // 2. either the reassociation in the collapse or the expand is of size 1.
175+ ArrayRef<int64_t > staticSourceSize = collapseOp.getSrcType ().getShape ();
176+ ArrayRef<int64_t > staticResultSize = expandOp.getStaticOutputShape ();
172177 for (auto [expandReassociation, collapseReassociation] :
173178 llvm::zip_equal (expandReInds, collapseReInds)) {
179+ if (collapseReassociation.size () == expandReassociation.size ()) {
180+ // Even if the reassociations are the same, the collapse/expand should
181+ // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
182+ // into 4x8x2 again. In presence of dynamic dimensions one can only
183+ // verify "equality" when there is only one dynamic dimension present,
184+ // and all other static dimensions are equal.
185+ ArrayRef<int64_t > collapsedStaticShapes = staticSourceSize.slice (
186+ collapseReassociation.front (), collapseReassociation.size ());
187+ int64_t numCollapsedDynamic =
188+ llvm::count_if (collapsedStaticShapes,
189+ [](int64_t d) { return ShapedType::isDynamic (d); });
190+ ArrayRef<int64_t > expandedStaticShapes = staticResultSize.slice (
191+ expandReassociation.front (), expandReassociation.size ());
192+ int64_t numExpandedDynamic =
193+ llvm::count_if (expandedStaticShapes,
194+ [](int64_t d) { return ShapedType::isDynamic (d); });
195+ if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
196+ collapsedStaticShapes != expandedStaticShapes) {
197+ return failure ();
198+ }
199+ continue ;
200+ }
201+ // If the reassociations are not the same, one or the other needs to be of
202+ // size one.
174203 if (collapseReassociation.size () != 1 && expandReassociation.size () != 1 )
175204 return failure ();
176205 }
177206
178207 // Compute new reassociation indices and expanded/collapsed shapes.
179208 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
180209 Location loc = expandOp->getLoc ();
181- SmallVector<OpFoldResult> collapseSizes =
210+ SmallVector<OpFoldResult> sourceSizes =
182211 tensor::getMixedSizes (rewriter, loc, collapseOp.getSrc ());
183- SmallVector<OpFoldResult> expandSizes (getMixedValues (
184- expandOp.getStaticOutputShape (), expandOp.getOutputShape (), rewriter));
212+ SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape ();
185213 SmallVector<OpFoldResult> newExpandSizes;
186- int64_t index = 0 , expandIndex = 0 , collapseIndex = 0 ;
187- for (auto [idx, collapseReassociation] : llvm::enumerate (collapseReInds)) {
214+
215+ int64_t newExpandIndex = 0 , newCollapseIndex = 0 , sourceSizeIndex = 0 ,
216+ resultSizeIndex = 0 ;
217+
218+ for (size_t idx = 0 , idxEnd = collapseReInds.size (); idx < idxEnd; idx++) {
219+ auto &collapseReassociation = collapseReInds[idx];
220+ auto &expandReassociation = expandReInds[idx];
221+
222+ // Case 1. The reassociations are the same in the collapse producer
223+ // and expand consumer. In the swapped expand, each of the final
224+ // dimensions are kept as is in the expand and the collapse. So,
225+ // for every element in the `ReassociationIndices` vector add a new
226+ // `ReassociationIndices` vector for the swapped expand and collapse
227+ // (of size 1).
228+ if (collapseReassociation.size () == expandReassociation.size ()) {
229+ for (size_t i = 0 ; i < collapseReassociation.size (); ++i) {
230+ newCollapseReInds.push_back ({newCollapseIndex++});
231+ newExpandReInds.push_back ({newExpandIndex++});
232+ newExpandSizes.push_back (resultSizes[resultSizeIndex++]);
233+ sourceSizeIndex++;
234+ }
235+ continue ;
236+ }
237+
238+ // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
239+ // in the expand is of size == 1). In this case, the original dimensions
240+ // are preserved on expansion and collapsed subsequently.
188241 if (collapseReassociation.size () != 1 ) {
189242 ReassociationIndices newCollapseReassociation;
190243 for (size_t i = 0 ; i < collapseReassociation.size (); ++i) {
191- newCollapseReassociation.push_back (index );
192- newExpandReInds.push_back ({index ++});
193- newExpandSizes.push_back (collapseSizes[collapseIndex ++]);
244+ newCollapseReassociation.push_back (newCollapseIndex++ );
245+ newExpandReInds.push_back ({newExpandIndex ++});
246+ newExpandSizes.push_back (sourceSizes[sourceSizeIndex ++]);
194247 }
248+ resultSizeIndex++;
195249 newCollapseReInds.push_back (newCollapseReassociation);
196- expandIndex++;
197250 continue ;
198251 }
252+
253+ // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
254+ // in the collapse is of size == 1). In this case, the expansion happens
255+ // first and the expanded dimensions are preserved on collapse.
199256 ReassociationIndices newExpandReassociation;
200- auto expandReassociation = expandReInds[idx];
201257 for (size_t i = 0 ; i < expandReassociation.size (); ++i) {
202- newExpandReassociation.push_back (index );
203- newCollapseReInds.push_back ({index ++});
204- newExpandSizes.push_back (expandSizes[expandIndex ++]);
258+ newExpandReassociation.push_back (newExpandIndex++ );
259+ newCollapseReInds.push_back ({newCollapseIndex ++});
260+ newExpandSizes.push_back (resultSizes[resultSizeIndex ++]);
205261 }
206262 newExpandReInds.push_back (newExpandReassociation);
207- collapseIndex ++;
263+ sourceSizeIndex ++;
208264 }
209265
210266 // Swap reshape order.
211267 SmallVector<Value> dynamicSizes;
212268 SmallVector<int64_t > staticSizes;
213269 dispatchIndexOpFoldResults (newExpandSizes, dynamicSizes, staticSizes);
214270 auto expandResultType = expandOp.getResultType ().clone (staticSizes);
215- auto newExpand = rewriter.create <tensor::ExpandShapeOp>(
216- loc, expandResultType, collapseOp.getSrc (), newExpandReInds,
217- newExpandSizes);
218- rewriter.replaceOpWithNewOp <tensor::CollapseShapeOp>(
219- expandOp, newExpand.getResult (), newCollapseReInds);
271+ Value newCollapseSrc = collapseOp.getSrc ();
272+ // If the number of reassociation indices in the new `expand_shape` op
273+ // matches the number of dimensions of the result, then the expand_shape
274+ // is a no-op.
275+ if (newExpandReInds.size () != newExpandSizes.size ()) {
276+ newCollapseSrc = rewriter.create <tensor::ExpandShapeOp>(
277+ loc, expandResultType, newCollapseSrc, newExpandReInds,
278+ newExpandSizes);
279+ }
280+
281+ // If the number of reassociation indices in the new `collapse_shape` op
282+ // matches the number of dimensions of the source, then the collapse_shape
283+ // is a no-op.
284+ Value replacement = newCollapseSrc;
285+ if (newCollapseReInds.size () != newExpandSizes.size ()) {
286+ replacement = rewriter.create <tensor::CollapseShapeOp>(
287+ loc, newCollapseSrc, newCollapseReInds);
288+ }
289+ rewriter.replaceOp (expandOp, replacement);
220290 return success ();
221291 }
222292};
0 commit comments