2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -43,6 +43,7 @@
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
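        # Edge lowering decomposes max_pool2d into max_pool2d_with_indices plus a
        # getitem selecting the value output; RemoveGetItemPass (reused from the
        # XNNPACK backend) folds the pair back into a single max_pool2d.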
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.avg_pool2d.default,
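            # aten.max_pool2d reaches the edge dialect as max_pool2d_with_indices;
            # RemoveGetItemPass later drops the unused indices output.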
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.repeat.default,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -20,6 +20,7 @@
op_get_item,
op_hardtanh,
op_log,
op_max_pool2d,
op_mm,
op_mul,
op_permute,
77 changes: 77 additions & 0 deletions backends/arm/operators/op_max_pool2d.py
@@ -0,0 +1,77 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import get_quant_node_args

from serializer.tosa_serializer import TosaOp


@register_node_visitor
class MaxPool2dVisitor(NodeVisitor):
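    # By the time this visitor runs, RemoveGetItemPass has folded
    # max_pool2d_with_indices + getitem back into the single-output
    # aten.max_pool2d targeted here.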
target = "aten.max_pool2d.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:

input_tensor = inputs[0]
kernel_size = inputs[1].special
stride = inputs[2].special
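        # aten.max_pool2d provides padding as an optional [pad_h, pad_w] pair,
        # while TOSA expects [top, bottom, left, right]; duplicate the pair and
        # fall back to zero padding when the argument is absent.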

try:
padding = [*inputs[3].special, *inputs[3].special]
except IndexError:
padding = [0, 0, 0, 0]

accumulator_type = input_tensor.dtype

if is_quant_node:
            # The accumulator type is always int8 when the input tensor is an integer type.
accumulator_type = ts.DType.INT8

        # Initialize zero points to zero.
input_zp = 0
output_zp = 0

if is_quant_node:
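            # Read the zero points from the adjacent quantization nodes: the
            # producer of the input and the first user of the output.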
input_zp = get_quant_node_args(
cast(torch.fx.Node, node.all_input_nodes[0])
).zp
output_zp = get_quant_node_args(list(node.users)[0]).zp

attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
kernel=kernel_size,
stride=stride,
pad=padding,
input_zp=input_zp,
output_zp=output_zp,
accum_dtype=accumulator_type,
)

tosa_graph.addOperator(
TosaOp.Op().MAX_POOL2D,
[input_tensor.name],
[output.name],
attr,
)
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
# TODO: remove?
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.avg_pool2d.default,
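        # Max pooling only selects among existing input values, so its output
        # can share the input's observer/fake-quantize parameters.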
torch.ops.aten.max_pool2d.default,
torch.ops.aten.full.default,
torch.ops.aten.flatten.using_ints,
torch.ops.aten.dropout.default,
11 changes: 11 additions & 0 deletions backends/arm/test/common.py
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):

# ==== End of Pytest hooks =====

# ==== Custom Pytest decorators =====


def expectedFailureOnFVP(test_item):
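    """Mark the test as an expected failure when run on the Corstone FVP.

    Has an effect only when the "corstone300" test option is enabled; otherwise
    the decorated test runs normally.
    """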
if is_option_enabled("corstone300"):
test_item.__unittest_expecting_failure__ = True
return test_item


# ==== End of Custom Pytest decorators =====


def load_libquantized_ops_aot_lib():
so_ext = {
248 changes: 248 additions & 0 deletions backends/arm/test/ops/test_max_pool.py
@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

test_data_suite = [
# (test_name, test_data, [kernel_size, stride, padding])
("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]),
("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]),
("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]),
]

test_data_suite_mult_batches = [
("randn", torch.randn(5, 16, 50, 32), [4, 2, 0]),
]


class TestMaxPool2d(unittest.TestCase):
"""Tests MaxPool2d."""

class MaxPool2d(torch.nn.Module):
def __init__(
self,
kernel_size: int | Tuple[int, int],
stride: int | Tuple[int, int],
padding: int | Tuple[int, int],
):
super().__init__()
self.max_pool_2d = torch.nn.MaxPool2d(
kernel_size=kernel_size, stride=stride, padding=padding
)

def forward(self, x):
return self.max_pool_2d(x)

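    # Pipeline naming follows the TOSA profiles: MI (Main Inference) runs the
    # module in floating point, BI (Base Inference) quantizes it first.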
def _test_maxpool2d_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
)
.export()
.check(["torch.ops.aten.max_pool2d.default"])
.check_not(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
.check_not(
[
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
]
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)

def _test_maxpool2d_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
.check_not(
[
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
]
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
)

def _test_maxpool2d_tosa_ethos_BI_pipeline(
self,
module: torch.nn.Module,
compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.max_pool2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.partition()
.check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.serialize()
)

return tester

@parameterized.expand(test_data_suite)
def test_maxpool2d_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
self._test_maxpool2d_tosa_MI_pipeline(
self.MaxPool2d(*model_params), (test_data,)
)

@parameterized.expand(test_data_suite)
def test_maxpool2d_tosa_BI(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
self._test_maxpool2d_tosa_BI_pipeline(
self.MaxPool2d(*model_params), (test_data,)
)

@parameterized.expand(test_data_suite)
def test_maxpool2d_tosa_u55_BI(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
self.MaxPool2d(*model_params),
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
(test_data,),
)
if common.is_option_enabled("corstone300"):
tester.run_method_and_compare_outputs(
qtol=1, inputs=(test_data,), target_board="corstone-300"
)

@parameterized.expand(test_data_suite)
def test_maxpool2d_tosa_u85_BI(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
self.MaxPool2d(*model_params),
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
(test_data,),
)
if common.is_option_enabled("corstone300"):
tester.run_method_and_compare_outputs(
qtol=1, inputs=(test_data,), target_board="corstone-320"
)

@parameterized.expand(test_data_suite_mult_batches)
def test_maxpool2d_tosa_MI_mult_batches(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
self._test_maxpool2d_tosa_MI_pipeline(
self.MaxPool2d(*model_params), (test_data,)
)

@parameterized.expand(test_data_suite_mult_batches)
def test_maxpool2d_tosa_BI_mult_batches(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
self._test_maxpool2d_tosa_BI_pipeline(
self.MaxPool2d(*model_params), (test_data,)
)

@parameterized.expand(test_data_suite_mult_batches)
@common.expectedFailureOnFVP # TODO: MLETORCH-433
def test_maxpool2d_tosa_u55_BI_mult_batches(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
self.MaxPool2d(*model_params),
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
(test_data,),
)
if common.is_option_enabled("corstone300"):
tester.run_method_and_compare_outputs(
qtol=1, inputs=(test_data,), target_board="corstone-300"
)

@parameterized.expand(test_data_suite_mult_batches)
@common.expectedFailureOnFVP # TODO: MLETORCH-433
def test_maxpool2d_tosa_u85_BI_mult_batches(
self,
test_name: str,
test_data: torch.Tensor,
        model_params: list[int],
):
tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
self.MaxPool2d(*model_params),
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
(test_data,),
)
if common.is_option_enabled("corstone300"):
tester.run_method_and_compare_outputs(
qtol=1, inputs=(test_data,), target_board="corstone-320"
)