Skip to content

Commit f5d1617

Browse files
Lallapalloozaaokblast
authored andcommitted
[mlir][tosa] Stop folding pad into avg_pool2d (llvm#164599)
Keep explicit padding ahead of tosa.avg_pool2d to preserve semantics. Folding a pad into the op drops padded values from the average divisor.
1 parent a2e8b48 commit f5d1617

File tree

3 files changed

+4
-34
lines changed

3 files changed

+4
-34
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
108108
LogicalResult verifyOutputZeroPoint(int64_t zp);
109109
}];
110110

111-
let hasCanonicalizer = 1;
112111
let hasVerifier = 1;
113112

114113
let assemblyFormat =

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,6 @@ namespace {
7575
template <typename OpTy>
7676
struct PoolPadFoldAdaptor;
7777

78-
template <>
79-
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
80-
using OpTy = tosa::AvgPool2dOp;
81-
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
82-
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
83-
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
84-
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
85-
return false;
86-
return true;
87-
}
88-
static bool checkPadConstCompliance(OpTy op, Value padConst) {
89-
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
90-
}
91-
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
92-
Value padInput, ArrayRef<int64_t> newPad) {
93-
rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
94-
op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
95-
op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
96-
op.getAccType());
97-
}
98-
};
99-
10078
template <>
10179
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
10280
using OpTy = tosa::MaxPool2dOp;
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
245223
};
246224
} // namespace
247225

248-
void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
249-
MLIRContext *context) {
250-
results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
251-
PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
252-
context);
253-
}
254-
255226
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
256227
MLIRContext *context) {
257228
results.add<

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
1818

1919
// -----
2020

21-
// CHECK-LABEL: @pad_wh_avg_pool2d_fold
22-
func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
23-
// CHECK-NOT: tosa.pad
21+
// CHECK-LABEL: @pad_wh_avg_pool2d_nofold
22+
func.func @pad_wh_avg_pool2d_nofold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
23+
// CHECK: tosa.pad
2424
// CHECK: tosa.avg_pool2d
25-
// CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
25+
// CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
2626
%pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
2727
%pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
2828
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>

0 commit comments

Comments
 (0)