@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
731731 }
732732};
733733
734+ struct PadSliceOptimization : public OpRewritePattern <tosa::SliceOp> {
735+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
736+
737+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
738+ PatternRewriter &rewriter) const override {
739+ Value sliceInput = sliceOp.getInput1 ();
740+
741+ // Check if producer is a PadOp
742+ auto padOp = sliceInput.getDefiningOp <tosa::PadOp>();
743+ if (!padOp)
744+ return rewriter.notifyMatchFailure (sliceOp,
745+ " slice input must be a pad operation" );
746+
747+ // Check PadOp has a single consumer
748+ if (!padOp->hasOneUse ())
749+ return rewriter.notifyMatchFailure (sliceOp,
750+ " pad shall have a single consumer" );
751+
752+ // Check input is statically ranked
753+ auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1 ().getType ());
754+ auto padTy = dyn_cast<RankedTensorType>(padOp.getType ());
755+ if (!inputTy || !padTy)
756+ return rewriter.notifyMatchFailure (
757+ sliceOp, " slice input must be a static ranked tensor" );
758+
759+ // Validate and extract tosa::PadOp padding
760+ DenseIntElementsAttr paddingElems;
761+ if (!matchPattern (padOp.getPadding (), m_Constant (&paddingElems))) {
762+ return rewriter.notifyMatchFailure (
763+ sliceOp,
764+ " The `padding` input specified on the tosa::PadOp must be constant." );
765+ }
766+ llvm::SmallVector<int64_t > padPaddings =
767+ llvm::to_vector (paddingElems.getValues <int64_t >());
768+
769+ // Extract slice parameters
770+ DenseElementsAttr startElems;
771+ if (!matchPattern (sliceOp.getStart (), m_Constant (&startElems)))
772+ return rewriter.notifyMatchFailure (
773+ sliceOp, " start of slice must be a static ranked shape" );
774+ llvm::SmallVector<int64_t > sliceStarts =
775+ llvm::to_vector (startElems.getValues <int64_t >());
776+
777+ DenseElementsAttr sizeElems;
778+ if (!matchPattern (sliceOp.getSize (), m_Constant (&sizeElems)))
779+ return rewriter.notifyMatchFailure (
780+ sliceOp, " size of slice must be a static ranked shape" );
781+ llvm::SmallVector<int64_t > sliceSizes =
782+ llvm::to_vector (sizeElems.getValues <int64_t >());
783+
784+ // Update the paddings
785+ int64_t rank = inputTy.getRank ();
786+ llvm::SmallVector<int64_t > newSliceStarts (rank, 0 );
787+ llvm::SmallVector<int64_t > newPadPaddings (2 * rank, 0 );
788+ llvm::SmallVector<int64_t > newPadShape (rank, 0 );
789+ bool updated = false ;
790+ for (int64_t i = 0 ; i < rank; ++i) {
791+ const int64_t padLo = padPaddings[i * 2 ];
792+ const int64_t padHi = padPaddings[i * 2 + 1 ];
793+ const int64_t sliceStart = sliceStarts[i];
794+ const int64_t sliceSize = sliceSizes[i];
795+ const int64_t sliceEnd = sliceStart + sliceSize;
796+
797+ const int64_t dimSize = inputTy.getShape ()[i];
798+ const int64_t dimStart = padLo;
799+ const int64_t dimEnd = padLo + dimSize;
800+ const int64_t dimTotal = padLo + dimSize + padHi;
801+
802+ // Check slice within bounds
803+ if (sliceStart < 0 || sliceEnd > dimTotal)
804+ return rewriter.notifyMatchFailure (sliceOp, " slice out-of-bounds" );
805+
806+ const int64_t newPadLo = std::max<int64_t >(padLo - sliceStart, 0 );
807+ const int64_t newPadHi =
808+ std::max<int64_t >(sliceEnd - (padLo + dimSize), 0 );
809+ const int64_t newSliceStart = std::max<int64_t >(sliceStart - padLo, 0 );
810+
811+ // Compute update slice/pad parameters
812+ if (sliceStart < dimStart || sliceEnd > dimEnd) {
813+ // Handle slice when not within the original input entirely
814+ updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
815+ (newSliceStart != sliceStart);
816+ newPadPaddings[i * 2 ] = newPadLo;
817+ newPadPaddings[i * 2 + 1 ] = newPadHi;
818+ newSliceStarts[i] = newSliceStart;
819+ } else {
820+ // Slice is within the original input
821+ updated |= newSliceStart != sliceStart;
822+ newSliceStarts[i] = newSliceStart;
823+ }
824+
825+ // Calculate new pad output shape
826+ newPadShape[i] =
827+ newPadPaddings[i * 2 ] + dimSize + newPadPaddings[i * 2 + 1 ];
828+ }
829+
830+ // Check that we actually need to proceed with the rewrite
831+ if (!updated)
832+ return rewriter.notifyMatchFailure (
833+ sliceOp, " terminate condition; nothing to rewrite" );
834+
835+ // Create a PadOp with updated padding
836+ auto newPaddingsOp =
837+ getTosaConstShape (rewriter, sliceOp.getLoc (), newPadPaddings);
838+ auto newPadTy =
839+ RankedTensorType::get (newPadShape, inputTy.getElementType ());
840+ auto newPadOp = rewriter.create <tosa::PadOp>(
841+ padOp.getLoc (), newPadTy, padOp.getInput1 (), newPaddingsOp,
842+ padOp.getPadConst ());
843+
844+ // Update SliceOp and point to new PadOp
845+ auto newStartOp =
846+ getTosaConstShape (rewriter, sliceOp.getLoc (), newSliceStarts);
847+ rewriter.replaceOpWithNewOp <tosa::SliceOp>(sliceOp, sliceOp.getType (),
848+ newPadOp.getResult (), newStartOp,
849+ sliceOp.getSize ());
850+
851+ return success ();
852+ }
853+ };
854+
734855// Update size operand of tosa.slice if size has dynamic dims but corresponding
735856// output dim is static
736857struct SliceDynamicSizeCanonicalization
@@ -779,8 +900,8 @@ struct SliceDynamicSizeCanonicalization
779900
780901void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
781902 MLIRContext *context) {
782- results.add <ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
783- context);
903+ results.add <ConcatSliceOptimization, PadSliceOptimization,
904+ SliceDynamicSizeCanonicalization>( context);
784905}
785906
786907// ===----------------------------------------------------------------------===//
0 commit comments