#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
+#include <algorithm>

using namespace mlir;
using namespace mlir::tensor;
@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
  }
};

+/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
+/// `tensor.extract_slice(tensor.collapse_shape)`.
+///
+/// For this transformation to be possible, the slice must be representable as
+/// a contiguous slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static, this is possible if the
+/// following conditions are met:
+/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
+/// be the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous block of memory if and only if there exists an index k in
+/// {0, 1, ..., n} such that:
+///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///   1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly one
+///   dimension),
+///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+///   in full).
+/// In other words, the slice shape S must be of the form:
+///   [ 1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n ]
+///
+/// In case the size and/or offset extracted are dynamic, this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+///   [ 1, 1, ..., 1, A, 1, ..., 1 ]
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic, by analyzing the possible values that could be
+/// given to the size/offset.
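+///
+/// For example, for a group collapsed from [1, 7] with a dynamic size %size
+/// and offset %off, the slice expands to sizes [1, %size] and offsets
+/// [0, %off], since 7 is the only non-unit dimension in the group.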
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16 -> 2x16], [1x7 -> 1x?],
+/// [20 -> 10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the
+    // result of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state
+    // after applying the pattern are named with the prefix "expanded".
+    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+
+    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector<ReassociationIndices, 4> reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration
+    // of the loop, the offsets and sizes are computed per reassociation
+    // group. All strides are 1, since only unit-stride (contiguous) slices
+    // are matched above.
+    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+                                              rewriter.getIndexAttr(1));
+
+    for (auto [groupIdx, reassocIndices] :
+         llvm::enumerate(reassociationIndices)) {
+      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+      // Case #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice only
+      // if there is a single dimension in the reassociation group that has a
+      // size not equal to 1.
+      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.emplace_back(collapsedSize);
+            expandedOffsets.emplace_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
+          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // Case #2 - size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of
+      // the src of the collapse_shape. This check must be done on the
+      // innermost dimensions first, so the reassociation group is traversed
+      // in reverse order.
+      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+      int64_t collapsedOffsetValue =
+          getConstantIntValue(collapsedOffset).value();
+
+      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
+        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+
+        // If the slice covers this entire dimension, make sure the slice
+        // size in this dimension can be set to the full shape size and the
+        // offset to 0.
+        if (collapsedSizeValue >= expandedShapeSize &&
+            (collapsedSizeValue % expandedShapeSize != 0 ||
+             collapsedOffsetValue % expandedShapeSize != 0)) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: cannot be extracted as a contiguous "
+                       "slice of the src of the collapse_shape");
+        }
+
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+
+        // This is the dimension that slicing occurs along, so make sure the
+        // slice size + offset will not exceed the shape size.
+        if (collapsedSizeValue < expandedShapeSize &&
+            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: slice cannot be extracted as a "
+                       "contiguous slice of the src of the collapse_shape");
+        }
+
+        groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+            std::min(collapsedSizeValue, expandedShapeSize)));
+        groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+
+        // Remove the size and offset of trailing dimensions from the size
+        // and offset of the slice.
+        collapsedSizeValue /= expandedShapeSize;
+        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+        collapsedOffsetValue /= expandedShapeSize;
+      }
+
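+      // The per-group sizes and offsets were computed innermost-first, so
+      // append them in reverse to restore the original dimension order.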
+      expandedSizes.append(groupExpandedSizes.rbegin(),
+                           groupExpandedSizes.rend());
+      expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                             groupExpandedOffsets.rend());
+    }
+
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
+        expandedSizes, expandedStrides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
+
+    return success();
+  }
+};
+
} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
-  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+               BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
}
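
A minimal usage sketch (assumed driver code, not part of this patch) showing
how the populated patterns might be applied inside a pass. The pass
boilerplate (runOnOperation, getContext, signalPassFailure) is assumed, and
the greedy driver entry point is applyPatternsGreedily in current MLIR
(applyPatternsAndFoldGreedily on older revisions):

  void runOnOperation() override {
    // Collect the bubble-up patterns, including the new
    // BubbleUpCollapseShapeThroughExtractSlice, and apply them greedily.
    RewritePatternSet patterns(&getContext());
    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }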