@@ -763,8 +763,19 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
763763 sliceOp, " slice input must be a static ranked tensor" );
764764 int32_t axis = concatOp.getAxis ();
765765
766- llvm::SmallVector<int64_t > sliceStart (sliceOp.getStart ());
767- llvm::ArrayRef<int64_t > sliceSize = sliceOp.getSize ();
766+ llvm::SmallVector<int64_t > sliceStart;
767+ if (!tosa::getConstShapeValue (sliceOp.getStart ().getDefiningOp (),
768+ sliceStart)) {
769+ return rewriter.notifyMatchFailure (
770+ sliceOp, " slice start must be a constant shape" );
771+ }
772+
773+ llvm::SmallVector<int64_t > sliceSize;
774+ if (!tosa::getConstShapeValue (sliceOp.getSize ().getDefiningOp (),
775+ sliceSize)) {
776+ return rewriter.notifyMatchFailure (sliceOp,
777+ " slice size must be a constant shape" );
778+ }
768779 llvm::SmallVector<Value> requiredConcatInputs;
769780 int64_t processedOriginalConcatInputSize = 0 ;
770781 int64_t droppedConcatInputSize = 0 ;
@@ -803,8 +814,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
803814 concatOp->getLoc (), requiredConcatInputs, axis);
804815 auto newSlice = rewriter.create <tosa::SliceOp>(
805816 sliceOp->getLoc (), sliceOp.getType (), newConcat,
806- rewriter. getDenseI64ArrayAttr ( sliceStart),
807- rewriter. getDenseI64ArrayAttr ( sliceSize));
817+ getTosaConstShape ( rewriter, sliceOp. getStart (). getLoc (), sliceStart),
818+ getTosaConstShape ( rewriter, sliceOp. getSize (). getLoc (), sliceSize));
808819 rewriter.replaceOp (sliceOp, newSlice);
809820 return success ();
810821 }
@@ -839,8 +850,21 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
839850 SmallVector<int64_t > tileMultiplies;
840851 const LogicalResult tileHasConstantMultiplies =
841852 tileOp.getConstantMultiples (tileMultiplies);
853+ llvm::SmallVector<int64_t > sliceStartShape;
854+ if (!tosa::getConstShapeValue (sliceOp.getStart ().getDefiningOp (),
855+ sliceStartShape)) {
856+ return rewriter.notifyMatchFailure (
857+ sliceOp, " slice start must be a constant shape" );
858+ }
859+
860+ llvm::SmallVector<int64_t > sliceSizeShape;
861+ if (!tosa::getConstShapeValue (sliceOp.getSize ().getDefiningOp (),
862+ sliceSizeShape)) {
863+ return rewriter.notifyMatchFailure (sliceOp,
864+ " slice size must be a constant shape" );
865+ }
842866 for (auto [axis, sliceStart, sliceSize] :
843- llvm::enumerate (sliceOp. getStart (), sliceOp. getSize () )) {
867+ llvm::enumerate (sliceStartShape, sliceSizeShape )) {
844868 if (sliceSize <= 0 ) {
845869 return rewriter.notifyMatchFailure (
846870 sliceOp, " degenerate slice with zero sized dim" );
@@ -878,16 +902,61 @@ struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
878902 tileOp->getOperand (0 ), constantShapeValue);
879903 auto newSlice = rewriter.create <tosa::SliceOp>(
880904 sliceOp->getLoc (), sliceOp.getType (), newTile,
881- rewriter.getDenseI64ArrayAttr (newTileStarts), sliceOp.getSizeAttr ());
905+ getTosaConstShape (rewriter, sliceOp.getStart ().getLoc (), newTileStarts),
906+ sliceOp.getSize ());
882907 rewriter.replaceOp (sliceOp, newSlice);
883908 return success ();
884909 }
885910};
886911
912+ // This pattern fuses consecutive slice operations into a single slice
913+ struct SliceSliceOptimization : public OpRewritePattern <tosa::SliceOp> {
914+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
915+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
916+ PatternRewriter &rewriter) const override {
917+
918+ auto precedingSliceOp = sliceOp.getInput1 ().getDefiningOp <SliceOp>();
919+ if (!precedingSliceOp)
920+ return failure ();
921+ SmallVector<int64_t > precedingSliceStart;
922+ if (!tosa::getConstShapeValue (precedingSliceOp.getStart ().getDefiningOp (),
923+ precedingSliceStart)) {
924+ return rewriter.notifyMatchFailure (
925+ sliceOp, " preceding slice start must be a constant shape" );
926+ }
927+ SmallVector<int64_t > thisSliceStart;
928+ if (!tosa::getConstShapeValue (sliceOp.getStart ().getDefiningOp (),
929+ thisSliceStart)) {
930+ return rewriter.notifyMatchFailure (
931+ sliceOp, " slice start must be a constant shape" );
932+ }
933+ SmallVector<int64_t > newSliceStart;
934+ newSliceStart.reserve (precedingSliceStart.size ());
935+ for (auto [startPreceding, startThis] :
936+ llvm::zip_equal (precedingSliceStart, thisSliceStart)) {
937+ newSliceStart.push_back (startPreceding + startThis);
938+ }
939+ Value newStartConst = getTosaConstShape (
940+ rewriter,
941+ rewriter.getFusedLoc ({sliceOp.getStart ().getLoc (),
942+ precedingSliceOp.getStart ().getLoc ()}),
943+ newSliceStart);
944+ rewriter.modifyOpInPlace (sliceOp, [&]() {
945+ sliceOp.getInput1Mutable ().assign (precedingSliceOp.getInput1 ());
946+ sliceOp.getStartMutable ().assign (newStartConst);
947+ sliceOp->setLoc (rewriter.getFusedLoc (
948+ {precedingSliceOp->getLoc (), sliceOp->getLoc ()}));
949+ });
950+
951+ return success ();
952+ }
953+ };
954+
887955void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
888956 MLIRContext *context) {
889957 results.add <ConcatSliceOptimization>(context);
890958 results.add <TileSliceOptimization>(context);
959+ results.add <SliceSliceOptimization>(context);
891960}
892961
893962struct MinToClampOptimization : public OpRewritePattern <tosa::MinimumOp> {
@@ -1525,30 +1594,6 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
15251594}
15261595
15271596OpFoldResult SliceOp::fold (FoldAdaptor adaptor) {
1528- const auto tryFoldWithPrecedingSlice = [this ](FoldAdaptor adaptor) {
1529- auto precedingSliceOp = getInput1 ().getDefiningOp <SliceOp>();
1530- if (!precedingSliceOp)
1531- return failure ();
1532- const auto precedingSliceStart = precedingSliceOp.getStart ();
1533- const auto thisSliceStart = getStart ();
1534- SmallVector<int64_t > newSliceStart;
1535- newSliceStart.reserve (precedingSliceStart.size ());
1536- for (auto [startPreceding, startThis] :
1537- llvm::zip_equal (precedingSliceStart, thisSliceStart)) {
1538- newSliceStart.push_back (startPreceding + startThis);
1539- }
1540- setOperand (precedingSliceOp->getOperand (0 ));
1541- setStart (newSliceStart);
1542- getOperation ()->setLoc (
1543- FusedLoc::get (getContext (), {precedingSliceOp->getLoc (), getLoc ()}));
1544- return success ();
1545- };
1546-
1547- // First try folding the preceding slice, this also works if the shapes are
1548- // dynamic
1549- if (succeeded (tryFoldWithPrecedingSlice (adaptor)))
1550- return getResult ();
1551-
15521597 auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1 ().getType ());
15531598 auto outputTy = llvm::dyn_cast<RankedTensorType>(getType ());
15541599
@@ -1573,7 +1618,12 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
15731618
15741619 if (inputTy.hasStaticShape () && outputTy.hasStaticShape () &&
15751620 outputTy.getNumElements () == 1 ) {
1576- llvm::SmallVector<uint64_t > indices (getStart ());
1621+ DenseElementsAttr startElems;
1622+ if (!matchPattern (getStart (), m_Constant (&startElems)))
1623+ return {};
1624+
1625+ llvm::SmallVector<uint64_t > indices =
1626+ llvm::to_vector (startElems.getValues <uint64_t >());
15771627 auto value = operand.getValues <Attribute>()[indices];
15781628 return SplatElementsAttr::get (outputTy, value);
15791629 }
0 commit comments