@@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
550550 enumerate(collapseShapeOp.getReassociationIndices ())) {
551551 OpFoldResult collapsedSize = collapsedSizes[groupIdx];
552552 OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
553- // Case #1 - size and/or offset are dynamic.
553+ // CASE #1 - size and/or offset are dynamic.
554554 // In this case, the slice can be represented as a contiguous slice only
555555 // if there is a single dimension in the reassociation group that has a
556556 // size not equal to 1.
@@ -576,16 +576,24 @@ struct BubbleUpCollapseShapeThroughExtractSlice
576576 continue ;
577577 }
578578
579- // Case #2 = size and offset are static.
579+ // CASE #2 - size and offset are static.
580580 // Verify that the slice can be represented as a contiguous slice of the
581581 // src of the collapse_shape.
582582 // Checking this is done in order of most internal dimensions first,
583583 // so traversal is done in reverse order of the reassociation group.
584584 // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
585585 // ...,An] then we first find the size and offset for n...k+1 then for k
586586 // and then for k-1...0.
587- int64_t collapsedSizeValue = getConstantIntValue (collapsedSize).value ();
588- int64_t collapsedOffsetValue =
587+
588+ // currentCollapsedsize and currentCollapsedOffset are initialized with
589+ // the original collapsed size and offset and divided by the expanded
590+ // shape size in each dimension as we go along the reassociation group.
591+ // In essence we are spreading the original collapsed size and offset over
592+ // the various expanded slice dimensions.
593+ // The variables are used both to check the validity of the slice and to
594+ // compute the expanded sizes and offsets.
595+ int64_t currentCollapsedsize = getConstantIntValue (collapsedSize).value ();
596+ int64_t currentCollapsedOffset =
589597 getConstantIntValue (collapsedOffset).value ();
590598
591599 SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
@@ -600,49 +608,55 @@ struct BubbleUpCollapseShapeThroughExtractSlice
600608 for (; idx < reassocGroupSize; ++idx) {
601609 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
602610
603- if (collapsedSizeValue < expandedShapeSize)
611+ if (currentCollapsedsize < expandedShapeSize)
604612 break ;
605613
606614 // We need to make sure that the slice size can be set to the shape size
607615 // and the offset to 0.
608- if ((collapsedSizeValue % expandedShapeSize) != 0 ||
609- (collapsedOffsetValue % expandedShapeSize) != 0 )
616+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
617+ (currentCollapsedOffset % expandedShapeSize) != 0 )
610618 return rewriter.notifyMatchFailure (
611619 sliceOp, " unsupported: cannot be extracted as a contiguous slice "
612620 " of the src of the collapse_shape" );
613621
614622 groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
615623 groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
616624
617- collapsedSizeValue /= expandedShapeSize;
618- collapsedOffsetValue /= expandedShapeSize;
625+ currentCollapsedsize /= expandedShapeSize;
626+ currentCollapsedOffset /= expandedShapeSize;
619627 }
620628
621629 // Now handle the first dim where slicing occurs on (k).
622630 if (idx < reassocGroupSize) {
623631 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
624- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
632+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
625633 // We need to make sure that the slice size in this dim + offset will
626634 // not exceed the shape size.
627- if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
635+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
628636 return rewriter.notifyMatchFailure (
629637 sliceOp, " unsupported: slice cannot be extracted as a contiguous "
630638 " slice of the src of the collapse_shape" );
631639
632- groupExpandedSizes.push_back (rewriter.getIndexAttr (collapsedSizeValue));
640+ groupExpandedSizes.push_back (
641+ rewriter.getIndexAttr (currentCollapsedsize));
633642 groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
634643
635- collapsedOffsetValue /= expandedShapeSize;
644+ currentCollapsedOffset /= expandedShapeSize;
636645 }
637646
638647 // Now handle the leading dimensions where the slice size is equal to 1
639648 // (k-1...0).
649+ // The size for these dimensions must be 1 because of how we constructed
650+ // the slice size of the expanded shape. We spread the original collapsed
651+ // size over the expanded shape sizes until we reached dimension k where
652+ // the remaining size was smaller than the expanded shape size, and spread
653+ // the remaining size on it. So, now we are left with only 1s.
640654 for (idx++; idx < reassocGroupSize; ++idx) {
641655 int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
642- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
656+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
643657 groupExpandedSizes.push_back (rewriter.getIndexAttr (1 ));
644658 groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
645- collapsedOffsetValue /= expandedShapeSize;
659+ currentCollapsedOffset /= expandedShapeSize;
646660 }
647661
648662 expandedSizes.append (groupExpandedSizes.rbegin (),
0 commit comments