From 3898792064216bace126d70862ec25be9617a5d2 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Sun, 13 Apr 2025 13:19:10 -0400 Subject: [PATCH 1/2] Revert "Revert "[tosa]: canonicalize dynamic size of tosa.slice to static output shape" (#135525)" This reverts commit d6e2aee9b1069b4a5fc1a0b07aef23b380f856f6. --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 55 ++++++++++++++++++- mlir/test/Dialect/Tosa/canonicalize.mlir | 15 +++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c4ef7d0bb9ff5..84f89bfd7f2d3 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -731,9 +731,62 @@ struct ConcatSliceOptimization : public OpRewritePattern { } }; +// Update size operand of tosa.slice if size has dynamic dims but corresponding +// output dim is static +struct SliceDynamicSizeCanonicalization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + ShapedType resultType = cast(sliceOp.getType()); + + ElementsAttr sizeElems; + if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) { + return rewriter.notifyMatchFailure( + sliceOp, "size of slice must be a static ranked shape"); + } + + llvm::SmallVector sliceSizes = + llvm::to_vector(sizeElems.getValues()); + + bool replaceSliceSize{false}; + // if size op has -1 indicating dynamic shape but corresponding dim on the + // output is statically known, update size to match with known output dim + // shape + for (const auto &[index, size] : llvm::enumerate(sliceSizes)) { + if (size == -1 && !resultType.isDynamicDim(index)) { + sliceSizes[index] = resultType.getDimSize(index); + replaceSliceSize = true; + } + } + + if (!replaceSliceSize) { + return rewriter.notifyMatchFailure( + sliceOp, "no dimension of size of slice is dynamic that resolves " + "to static output shape"); + } + + auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); + auto newSliceOp = rewriter.create( + sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(), + sliceOp.getStart(), size_op); + + rewriter.replaceOp(sliceOp, newSliceOp.getResult()); + + // Remove const_shape size op when it no longer has use point. + Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); + if (sizeConstShape->getResult(0).hasOneUse()) + rewriter.eraseOp(sizeConstShape); + + return success(); + } +}; + void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index b366b4f1e4fd4..a754a46be603f 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1212,3 +1212,18 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> { %16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32> return %16 : tensor<1x24x2xi32> } + + +// ---- +// CHECK-LABEL: func.func @slice_dynamic_size_static_output_canonicalize( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> { +// CHECK: %[[START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32> +// CHECK: return %[[SLICE]] +func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> { + %0 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + %1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> + %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32> + return %2 : tensor<2x60x58x?xf32> + } From e49d7749294386df3c4d0aedff81e51e04d3417d Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Sun, 13 Apr 2025 16:09:41 -0400 Subject: [PATCH 2/2] [tosa] : Re-enable PR #135429 with ASAN fix. --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 ------ mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 84f89bfd7f2d3..47368532df169 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -773,12 +773,6 @@ struct SliceDynamicSizeCanonicalization sliceOp.getStart(), size_op); rewriter.replaceOp(sliceOp, newSliceOp.getResult()); - - // Remove const_shape size op when it no longer has use point. - Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); - if (sizeConstShape->getResult(0).hasOneUse()) - rewriter.eraseOp(sizeConstShape); - return success(); } }; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index a754a46be603f..d153474593d80 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1214,7 +1214,7 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> { } -// ---- +// ----- // CHECK-LABEL: func.func @slice_dynamic_size_static_output_canonicalize( // CHECK-SAME: %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> { // CHECK: %[[START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>