@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
403
403
srcType.getRank () < dstType.getRank () ||
404
404
parentSrcType.getRank () == dstType.getRank ())
405
405
return failure ();
406
+
406
407
// Check if the result tensor_reshape after folding the reshapeOp and
407
408
// parentReshapeOp are combined.
408
409
// If the final tensor_reshape is folding, the parentReshapeOp is
409
410
// introducing unit-dims, and the reshapeOp does an actual reshape.
410
- // If the final tensor_reshape op is expanding, the reshapeOp is introducing
411
- // unit-dims, and the parentReshapeOp does an actual reshape.
411
+ // If the final tensor_reshape op is expanding, the reshapeOp is
412
+ // introducing unit-dims, and the parentReshapeOp does an actual reshape.
412
413
bool isFoldingPattern = parentSrcType.getRank () > dstType.getRank ();
413
- auto reassociationMaps = isFoldingPattern
414
- ? reshapeOp.getReassociationMaps ()
415
- : parentReshapeOp.getReassociationMaps ();
416
- DenseSet<unsigned > conservedDimensions;
417
- for (auto &map : reassociationMaps) {
418
- if (map.getNumResults () == 1 ) {
419
- conservedDimensions.insert (
420
- map.getResult (0 ).cast <AffineDimExpr>().getPosition ());
421
- }
422
- }
423
-
424
- // Find positions at which the unit-dims exist.
425
- int64_t nonUnitDimPos = 0 ;
426
- DenseMap<unsigned , unsigned > nonUnitSrcDims;
427
- ArrayRef<int64_t > nonUnitShape =
414
+ ArrayRef<int64_t > expandedShape =
428
415
isFoldingPattern ? parentSrcType.getShape () : dstType.getShape ();
429
- for (auto shape : enumerate(srcType.getShape ())) {
430
- // Case 1 : It is a conserved dimension.
431
- if (conservedDimensions.count (shape.index ())) {
432
- nonUnitSrcDims[shape.index ()] = nonUnitDimPos++;
433
- continue ;
416
+ ArrayRef<int64_t > foldedShape =
417
+ isFoldingPattern ? dstType.getShape () : parentSrcType.getShape ();
418
+
419
+ unsigned expandedDim = 0 , foldedDim = 0 ;
420
+ SmallVector<SmallVector<AffineExpr, 4 >, 4 > reassociationExprs (
421
+ foldedShape.size ());
422
+ while (expandedDim < expandedShape.size () &&
423
+ foldedDim < foldedShape.size ()) {
424
+ int64_t dstSize = foldedShape[foldedDim];
425
+ int64_t srcSize = expandedShape[expandedDim];
426
+ while (srcSize < dstSize && expandedDim < expandedShape.size ()) {
427
+ reassociationExprs[foldedDim].push_back (
428
+ rewriter.getAffineDimExpr (expandedDim++));
429
+ srcSize *= expandedShape[expandedDim];
434
430
}
435
- // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
436
- if (shape.value () == 1 )
437
- continue ;
438
- // Case 3 : Dimensions match, treat it as a non-unit src dim.
439
- if (nonUnitDimPos < static_cast <int64_t >(nonUnitShape.size ()) &&
440
- nonUnitShape[nonUnitDimPos] == shape.value ()) {
441
- nonUnitSrcDims[shape.index ()] = nonUnitDimPos++;
442
- continue ;
431
+ if (srcSize == dstSize) {
432
+ reassociationExprs[foldedDim].push_back (
433
+ rewriter.getAffineDimExpr (expandedDim++));
434
+ // If the next dim in foldedShape is not 1, treat subsequent dims in
435
+ // expandedShape which are 1 to be collapsed.
436
+ if (foldedDim == foldedShape.size () - 1 ||
437
+ foldedShape[foldedDim + 1 ] != 1 ) {
438
+ while (expandedDim < expandedShape.size () &&
439
+ expandedShape[expandedDim] == 1 ) {
440
+ reassociationExprs[foldedDim].push_back (
441
+ rewriter.getAffineDimExpr (expandedDim++));
442
+ }
443
+ }
444
+ } else {
445
+ return failure ();
443
446
}
444
- return failure () ;
447
+ foldedDim++ ;
445
448
}
449
+ if (expandedDim != expandedShape.size ())
450
+ return failure ();
446
451
447
- // Compute reassociation maps for the final operation. Use the reassociation
448
- // maps that is actually doing a reshape (and not just introducing
449
- // unit-dims). From these maps, prune the unit-extent dimensions.
450
- for (AffineMap &map : reassociationMaps) {
451
- SmallVector<AffineExpr, 4 > exprs;
452
- exprs.reserve (nonUnitSrcDims.size ());
453
- for (auto result : map.getResults ()) {
454
- unsigned dim = result.cast <AffineDimExpr>().getPosition ();
455
- if (nonUnitSrcDims.count (dim))
456
- exprs.push_back (rewriter.getAffineDimExpr (nonUnitSrcDims[dim]));
457
- }
458
- map = AffineMap::get (nonUnitSrcDims.size (), 0 , exprs,
459
- rewriter.getContext ());
460
- }
452
+ SmallVector<AffineMap, 4 > reassociationMaps =
453
+ llvm::to_vector<4 >(llvm::map_range (
454
+ reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
455
+ return AffineMap::get (expandedShape.size (), 0 , exprs,
456
+ rewriter.getContext ());
457
+ }));
461
458
rewriter.replaceOpWithNewOp <TensorReshapeOp>(
462
459
reshapeOp, dstType, parentReshapeOp.src (),
463
460
rewriter.getAffineMapArrayAttr (reassociationMaps));
0 commit comments