@@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice
429429 }
430430};
431431
432- // / Converts `tensor.collapse_shape (tensor.extract_slice )` to
433- // / `tensor.extract_slice (tensor.collapse_shape )`.
432+ // / Converts `tensor.extract_slice (tensor.collapse_shape )` to
433+ // / `tensor.collapse_shape (tensor.extract_slice )`.
434434// /
435- // / For this transformation to be possible, the slice must be representable as a
436- // / contiguous slice within each reassociation group of the src.
435+ // / For this transformation to be possible - after bubbling up, the extraction
436+ // / of the contiguous slice must be representable as a single slice obtained via
437+ // / tensor.extract_slice within each reassociation group of the src.
437438// /
438439// / In case the size and offset extracted are static then this is possible if
439- // / the following conditions are met:
440- // / Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
441- // / be the shape of a desired slice. A slice of shape S can be extracted as a
442- // / contiguous block of memory if and only if there exists an index k in {0, 1,
440+ // / the following conditions are met within each reassociation group:
441+ // / Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
442+ // / dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
443+ // / shape of a desired slice. A slice of shape S can be extracted as a
444+ // / contiguous span of elements if and only if there exists an index k in {0, 1,
443445// / ..., n} such that:
444446// / S_i = 1 for all i < k (that is, all leading dimensions are singleton),
445447// / 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
@@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
475477// / %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
476478// / tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
477479// / ```
480+ // /
481+ // / Negative example:
482+ // / The transformation is not possible because we cannot use a single slice to
483+ // / represent the reassociation group [2x3x10->???]. If we would want the
484+ // / collapse to be after the extraction, we would need to extract multiple
485+ // / slices and concat them together.
486+ // / ```
487+ // / %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
488+ // / tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
489+ // / tensor<60xf32> to tensor<15xf32>
490+ // / ```
491+ // / If we would want the collapse to be after the extraction, a possible
492+ // / alternate transformation could be to extract multiple slices and concat them
493+ // / together:
494+ // / ```
495+ // / %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
496+ // / tensor<2x3x10xf32> to tensor <1x1x10xf32>
497+ // / %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
498+ // / tensor<2x3x10xf32> to tensor <1x1x5xf32>
499+ // / %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
500+ // / (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
501+ // / %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
502+ // / to tensor<15xf32>
503+ // / ```
504+ // / But this is not the intended purpose of the transformation.
478505struct BubbleUpCollapseShapeThroughExtractSlice
479506 : public OpRewritePattern<tensor::ExtractSliceOp> {
480507 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
552579 // Case #2 = size and offset are static.
553580 // Verify that the slice can be represented as a contiguous slice of the
554581 // src of the collapse_shape.
555- // Checking this must be done on order of most
556- // internal dimensions first, so traversal is done in reverse order of the
557- // reassociation group.
582+ // Checking this is done on order of most internal dimensions first,
583+ // so traversal is done in reverse order of the reassociation group.
584+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
585+ // ...,An] then we first find the size and offset for n...k+1 then for k
586+ // and then for k-1...0.
558587 int64_t collapsedSizeValue = getConstantIntValue (collapsedSize).value ();
559588 int64_t collapsedOffsetValue =
560589 getConstantIntValue (collapsedOffset).value ();
561590
562591 SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
563592
564- for (int64_t expandedShapeIdx : llvm::reverse (reassocIndices)) {
565- int64_t expandedShapeSize = srcShape[expandedShapeIdx];
593+ ReassociationIndices reversedReassocIndices (reassocIndices.rbegin (),
594+ reassocIndices.rend ());
595+ int64_t idx = 0 ;
596+ int64_t reassocGroupSize = reassocIndices.size ();
597+
598+ // First handle the trailing dimensions where the slice size should be
599+ // equal to the tensor shape and the offset should be 0 (n...k+1).
600+ for (; idx < reassocGroupSize; ++idx) {
601+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
566602
567- // This is a dimension that slicing will occur on, so need to make sure
568- // that the slice size can be set to the shape size and the offset to 0.
569- if (collapsedSizeValue >= expandedShapeSize &&
570- (collapsedSizeValue % expandedShapeSize != 0 ||
571- collapsedOffsetValue % expandedShapeSize != 0 )) {
603+ if (collapsedSizeValue < expandedShapeSize)
604+ break ;
605+
606+ // We need to make sure that the slice size can be set to the shape size
607+ // and the offset to 0.
608+ if ((collapsedSizeValue % expandedShapeSize) != 0 ||
609+ (collapsedOffsetValue % expandedShapeSize) != 0 )
572610 return rewriter.notifyMatchFailure (
573611 sliceOp, " unsupported: cannot be extracted as a contiguous slice "
574612 " of the src of the collapse_shape" );
575- }
576613
577- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
614+ groupExpandedSizes.push_back (rewriter.getIndexAttr (expandedShapeSize));
615+ groupExpandedOffsets.push_back (rewriter.getIndexAttr (0 ));
616+
617+ collapsedSizeValue /= expandedShapeSize;
618+ collapsedOffsetValue /= expandedShapeSize;
619+ }
578620
579- // This is the dimension that slicing will occur along, so need to make
580- // sure that the slice size + offset will not exceed the shape size.
581- if (collapsedSizeValue < expandedShapeSize &&
582- (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
621+ // Now handle the first dim where slicing occurs on (k).
622+ if (idx < reassocGroupSize) {
623+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
624+ int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
625+ // We need to make sure that the slice size in this dim + offset will
626+ // not exceed the shape size.
627+ if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
583628 return rewriter.notifyMatchFailure (
584629 sliceOp, " unsupported: slice cannot be extracted as a contiguous "
585630 " slice of the src of the collapse_shape" );
586- }
587631
588- groupExpandedSizes.push_back (rewriter.getIndexAttr (
589- std::min (collapsedSizeValue, expandedShapeSize)));
632+ groupExpandedSizes.push_back (rewriter.getIndexAttr (collapsedSizeValue));
590633 groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
591634
592- // Remove the size and offset of trailing dimensions from the size and
593- // offset of the slice.
594- collapsedSizeValue /= expandedShapeSize;
595- collapsedSizeValue = std::max<int64_t >(collapsedSizeValue, 1 );
635+ collapsedOffsetValue /= expandedShapeSize;
636+ }
637+
638+ // Now handle the leading dimensions where the slice size is equal to 1
639+ // (k-1...0).
640+ for (idx++; idx < reassocGroupSize; ++idx) {
641+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
642+ int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
643+ groupExpandedSizes.push_back (rewriter.getIndexAttr (1 ));
644+ groupExpandedOffsets.push_back (rewriter.getIndexAttr (offsetInDim));
596645 collapsedOffsetValue /= expandedShapeSize;
597646 }
598647
0 commit comments