Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasCanonicalizer = 1;
let hasVerifier = 1;

let assemblyFormat =
Expand Down
29 changes: 0 additions & 29 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,6 @@ namespace {
template <typename OpTy>
struct PoolPadFoldAdaptor;

template <>
struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
using OpTy = tosa::AvgPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
const llvm::ArrayRef<int64_t> kernel = op.getKernel();
if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
return false;
return true;
}
static bool checkPadConstCompliance(OpTy op, Value padConst) {
return checkMatchingPadConstAndZp(padConst, op.getInputZp());
}
static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
Value padInput, ArrayRef<int64_t> newPad) {
rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
op.getAccType());
}
};

template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
Expand Down Expand Up @@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
};
} // namespace

void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
context);
}

void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {

// -----

// CHECK-LABEL: @pad_wh_avg_pool2d_fold
func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
// CHECK-NOT: tosa.pad
// CHECK-LABEL: @pad_wh_avg_pool2d_nofold
func.func @pad_wh_avg_pool2d_nofold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
// CHECK: tosa.pad
// CHECK: tosa.avg_pool2d
// CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
// CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
%pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : ()-> tensor<1xf32>
Expand Down