@@ -731,6 +731,135 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };

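+// Fold a tosa.pad producer into the consuming tosa.slice: keep only the
+// padding the slice still reads and re-anchor the slice start against the
+// unpadded input. For example, padding a 4-element dim by [2, 2] and then
+// slicing [1, 7) needs only [1, 1] padding, with the slice start moved to 0.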
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Update the paddings
+    int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
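+    // Track whether any pad/slice parameter actually changes.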
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
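+      // Extent of the original (unpadded) data within the padded dimension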
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimStart = padLo;
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
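+      // Clamp the new padding to the region the slice overlaps; a slice
+      // start inside the original data becomes the new start in that input.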
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Compute updated slice/pad parameters
+      if (sliceStart < dimStart || sliceEnd > dimEnd) {
+        // Handle slice when not within the original input entirely
+        updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                   (newSliceStart != sliceStart);
+        newPadPaddings[i * 2] = newPadLo;
+        newPadPaddings[i * 2 + 1] = newPadHi;
+        newSliceStarts[i] = newSliceStart;
+      } else {
+        // Slice is within the original input
+        updated |= newSliceStart != sliceStart;
+        newSliceStarts[i] = newSliceStart;
+      }
+
+      // Calculate new pad output shape
+      newPadShape[i] =
+          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+    }
+
+    // Check that we actually need to proceed with the rewrite
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
 // Update size operand of tosa.slice if size has dynamic dims but corresponding
 // output dim is static
 struct SliceDynamicSizeCanonicalization
@@ -779,8 +908,8 @@ struct SliceDynamicSizeCanonicalization

 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
 }

 //===----------------------------------------------------------------------===//