@@ -731,9 +731,64 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
731731 }
732732};
733733
734+ // Update size operand of tosa.slice if size has dynamic dims but corresponding
735+ // output dim is static
736+ struct SliceDynamicSizeCanonicalization
737+ : public OpRewritePattern<tosa::SliceOp> {
738+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
739+
740+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
741+ PatternRewriter &rewriter) const override {
742+ ShapedType resultType = cast<ShapedType>(sliceOp.getType ());
743+
744+ ElementsAttr sizeElems;
745+ if (!matchPattern (sliceOp.getSize (), m_Constant (&sizeElems))) {
746+ return rewriter.notifyMatchFailure (
747+ sliceOp, " size of slice must be a static ranked shape" );
748+ }
749+
750+ llvm::SmallVector<int64_t > sliceSizes =
751+ llvm::to_vector (sizeElems.getValues <int64_t >());
752+
753+ bool replaceSliceSize{false };
754+ // if size op has -1 indicating dynamic shape but corresponding dim on the
755+ // output is statically known, update size to match with known output dim
756+ // shape
757+ for (const auto i : llvm::enumerate (sliceSizes)) {
758+ int64_t size = i.value ();
759+ size_t index = i.index ();
760+ if (size == -1 && !resultType.isDynamicDim (index)) {
761+ sliceSizes[index] = resultType.getDimSize (index);
762+ replaceSliceSize = true ;
763+ }
764+ }
765+
766+ if (!replaceSliceSize) {
767+ return rewriter.notifyMatchFailure (
768+ sliceOp, " no dimension of size of slice is dynamic that resolves "
769+ " to static output shape" );
770+ }
771+
772+ auto size_op = getTosaConstShape (rewriter, sliceOp.getLoc (), sliceSizes);
773+ auto newSliceOp = rewriter.create <tosa::SliceOp>(
774+ sliceOp.getLoc (), sliceOp.getType (), sliceOp.getInput1 (),
775+ sliceOp.getStart (), size_op);
776+
777+ rewriter.replaceOp (sliceOp, newSliceOp.getResult ());
778+
779+ // Remove const_shape size op when it no longer has use point.
780+ Operation *sizeConstShape = sliceOp.getSize ().getDefiningOp ();
781+ if (sizeConstShape->getResult (0 ).hasOneUse ())
782+ rewriter.eraseOp (sizeConstShape);
783+
784+ return success ();
785+ }
786+ };
787+
734788void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
735789 MLIRContext *context) {
736- results.add <ConcatSliceOptimization>(context);
790+ results.add <ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
791+ context);
737792}
738793
739794// ===----------------------------------------------------------------------===//
0 commit comments