@@ -731,9 +731,62 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
731731 }
732732};
733733
734+ // Update size operand of tosa.slice if size has dynamic dims but corresponding
735+ // output dim is static
736+ struct SliceDynamicSizeCanonicalization
737+ : public OpRewritePattern<tosa::SliceOp> {
738+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
739+
740+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
741+ PatternRewriter &rewriter) const override {
742+ ShapedType resultType = cast<ShapedType>(sliceOp.getType ());
743+
744+ ElementsAttr sizeElems;
745+ if (!matchPattern (sliceOp.getSize (), m_Constant (&sizeElems))) {
746+ return rewriter.notifyMatchFailure (
747+ sliceOp, " size of slice must be a static ranked shape" );
748+ }
749+
750+ llvm::SmallVector<int64_t > sliceSizes =
751+ llvm::to_vector (sizeElems.getValues <int64_t >());
752+
753+ bool replaceSliceSize{false };
754+ // if size op has -1 indicating dynamic shape but corresponding dim on the
755+ // output is statically known, update size to match with known output dim
756+ // shape
757+ for (const auto &[index, size] : llvm::enumerate (sliceSizes)) {
758+ if (size == -1 && !resultType.isDynamicDim (index)) {
759+ sliceSizes[index] = resultType.getDimSize (index);
760+ replaceSliceSize = true ;
761+ }
762+ }
763+
764+ if (!replaceSliceSize) {
765+ return rewriter.notifyMatchFailure (
766+ sliceOp, " no dimension of size of slice is dynamic that resolves "
767+ " to static output shape" );
768+ }
769+
770+ auto size_op = getTosaConstShape (rewriter, sliceOp.getLoc (), sliceSizes);
771+ auto newSliceOp = rewriter.create <tosa::SliceOp>(
772+ sliceOp.getLoc (), sliceOp.getType (), sliceOp.getInput1 (),
773+ sliceOp.getStart (), size_op);
774+
775+ rewriter.replaceOp (sliceOp, newSliceOp.getResult ());
776+
777+ // Remove const_shape size op when it no longer has use point.
778+ Operation *sizeConstShape = sliceOp.getSize ().getDefiningOp ();
779+ if (sizeConstShape->getResult (0 ).hasOneUse ())
780+ rewriter.eraseOp (sizeConstShape);
781+
782+ return success ();
783+ }
784+ };
785+
734786void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
735787 MLIRContext *context) {
736- results.add <ConcatSliceOptimization>(context);
788+ results.add <ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
789+ context);
737790}
738791
739792// ===----------------------------------------------------------------------===//
0 commit comments