diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 697a04e94441a..137554f49460d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -108,7 +108,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { LogicalResult verifyOutputZeroPoint(int64_t zp); }]; - let hasCanonicalizer = 1; let hasVerifier = 1; let assemblyFormat = diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 99b7cda49094e..a85ff10aa0d73 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -75,28 +75,6 @@ namespace { template struct PoolPadFoldAdaptor; -template <> -struct PoolPadFoldAdaptor { - using OpTy = tosa::AvgPool2dOp; - static bool checkKernelCompliance(OpTy op, const ArrayRef newPad) { - const llvm::ArrayRef kernel = op.getKernel(); - if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] || - newPad[0] >= kernel[0] || newPad[1] >= kernel[0]) - return false; - return true; - } - static bool checkPadConstCompliance(OpTy op, Value padConst) { - return checkMatchingPadConstAndZp(padConst, op.getInputZp()); - } - static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op, - Value padInput, ArrayRef newPad) { - rewriter.replaceOpWithNewOp( - op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(), - op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad), - op.getAccType()); - } -}; - template <> struct PoolPadFoldAdaptor { using OpTy = tosa::MaxPool2dOp; @@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern { }; } // namespace -void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add>>( - context); -} - void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add< diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 7574afa215e78..5a40f3fa8572c 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -18,11 +18,11 @@ func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor { // ----- -// CHECK-LABEL: @pad_wh_avg_pool2d_fold -func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> { - // CHECK-NOT: tosa.pad +// CHECK-LABEL: @pad_wh_avg_pool2d_nofold +func.func @pad_wh_avg_pool2d_nofold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> { + // CHECK: tosa.pad // CHECK: tosa.avg_pool2d - // CHECK-SAME: pad = array + // CHECK-SAME: pad = array %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32> %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>