@@ -687,8 +687,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
687687 sliceOp, " degenerate slice with zero sized dim in output" );
688688 }
689689 sliceStart[axis] -= droppedConcatInputSize;
690- auto newConcat = rewriter.create <tosa::ConcatOp>(concatOp-> getLoc (),
691- requiredConcatInputs, axis);
690+ auto newConcat = rewriter.create <tosa::ConcatOp>(
691+ concatOp-> getLoc (), requiredConcatInputs, axis);
692692 auto newSlice = rewriter.create <tosa::SliceOp>(
693693 sliceOp->getLoc (), sliceOp.getType (), newConcat,
694694 rewriter.getDenseI64ArrayAttr (sliceStart),
@@ -698,9 +698,75 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
698698 }
699699};
700700
701+ // / This patterns adjust the multipliers of a tile followed by a slice to only
702+ // / tile as much data as it is required by the slice
703+ struct TileSliceOptimization : public OpRewritePattern <tosa::SliceOp> {
704+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
705+
706+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
707+ PatternRewriter &rewriter) const override {
708+ Value sliceInput = sliceOp.getInput1 ();
709+ auto tileOp = sliceInput.getDefiningOp <tosa::TileOp>();
710+ if (!tileOp)
711+ return rewriter.notifyMatchFailure (sliceOp,
712+ " slice input must be tile operation" );
713+ if (!tileOp->hasOneUse ())
714+ return rewriter.notifyMatchFailure (
715+ sliceOp, " preceding tile must have a single use" ); // Do not insert
716+ // additional tiles
717+
718+ const auto tileOpInputType =
719+ dyn_cast<RankedTensorType>(tileOp->getOperand (0 ).getType ());
720+ if (!tileOpInputType || !tileOpInputType.hasStaticShape ())
721+ return rewriter.notifyMatchFailure (
722+ sliceOp, " input to preceding tile op must be a static ranked tensor" );
723+ llvm::SmallVector<int64_t > requiredMultipliers;
724+ llvm::SmallVector<int64_t > newTileStarts;
725+ requiredMultipliers.reserve (tileOpInputType.getRank ());
726+ newTileStarts.reserve (tileOpInputType.getRank ());
727+ for (auto [axis, sliceStart, sliceSize] :
728+ llvm::enumerate (sliceOp.getStart (), sliceOp.getSize ())) {
729+ if (sliceSize <= 0 ) {
730+ return rewriter.notifyMatchFailure (
731+ sliceOp, " degenerate slice with zero sized dim" );
732+ }
733+ const int64_t tileInputDimSize = tileOpInputType.getDimSize (axis);
734+ const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
735+ const int64_t sliceSizeInFirstTile =
736+ std::min (tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
737+ assert (sliceSizeInFirstTile > 0 );
738+ const int64_t requiredMultiplierWithoutFirstTile =
739+ llvm::divideCeil (sliceSize - sliceSizeInFirstTile, tileInputDimSize);
740+ const int64_t requiredMultiplier =
741+ requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0 );
742+ assert (requiredMultiplier <= tileOp.getMultiples ()[axis]);
743+ requiredMultipliers.push_back (requiredMultiplier);
744+ newTileStarts.push_back (sliceOffsetInNewFirstTile);
745+ }
746+ if (requiredMultipliers == tileOp.getMultiples ())
747+ return rewriter.notifyMatchFailure (
748+ sliceOp, " could not reduce multipliers in preceding tile" );
749+
750+ llvm::SmallVector<int64_t > newTileShape (tileOpInputType.getShape ());
751+ for (auto [newShape, multiplier] :
752+ llvm::zip_equal (newTileShape, requiredMultipliers)) {
753+ newShape *= multiplier;
754+ }
755+ auto newTile = rewriter.create <tosa::TileOp>(
756+ tileOp->getLoc (), tileOpInputType.clone (newTileShape),
757+ tileOp->getOperand (0 ), requiredMultipliers);
758+ auto newSlice = rewriter.create <tosa::SliceOp>(
759+ sliceOp->getLoc (), sliceOp.getType (), newTile,
760+ rewriter.getDenseI64ArrayAttr (newTileStarts), sliceOp.getSizeAttr ());
761+ rewriter.replaceOp (sliceOp, newSlice);
762+ return success ();
763+ }
764+ };
765+
701766void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
702767 MLIRContext *context) {
703768 results.add <ConcatSliceOptimization>(context);
769+ results.add <TileSliceOptimization>(context);
704770}
705771
706772struct MinToClampOptimization : public OpRewritePattern <tosa::MinimumOp> {
0 commit comments