Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class InsertTableOpsPass(ExportPass):
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
}

def __init__(self, exported_program: ExportedProgram) -> None:
Expand Down
36 changes: 34 additions & 2 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,40 @@ def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose_if_quant_op = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: why another list?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This list is for ops where we specifically only want to not decompose if it's quantized. Also, another operator will be coming in a follow on commit, so a list is needed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that you are using the same list in the filter_fn, sounds good, false alarm.

torch.ops.aten.hardsigmoid.default,
]

def filter_fn(node: torch.fx.Node) -> bool:
# This function filters for operators to not decompose where:
# - It's target is in ops_to_not_decompose_if_quant_op list.
# - All it's inputs/outputs are quantize operators.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
q = torch.ops.quantized_decomposed.quantize_per_tensor.default

if node.target in ops_to_not_decompose_if_quant_op:
# Assume we should not decompose the operator (it is quantized)
should_not_decompose = True

input_nodes = node.all_input_nodes
ouput_nodes = node.users

for inp in input_nodes:
if inp.target != dq:
should_not_decompose = False

for out in ouput_nodes:
if out.target != q:
should_not_decompose = False

return should_not_decompose

# Be default, do not decompose the operator
return True

ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.upsample_nearest2d.vec,
]
return (ops_to_not_decompose, None)
] + ops_to_not_decompose_if_quant_op

return (ops_to_not_decompose, filter_fn)
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.hardsigmoid.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _match_pattern(
torch.ops.aten.sigmoid.default,
torch.ops.aten.tanh.default,
torch.ops.aten.sum.dim_IntList,
torch.ops.aten.hardsigmoid.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
128 changes: 128 additions & 0 deletions backends/arm/test/ops/test_hardsigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import pytest
import torch

from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
# (test_name, test_data)
("zeros", torch.zeros(1, 10, 10, 10)),
("ones", torch.ones(10, 10, 10)),
("rand", torch.rand(10, 10) - 0.5),
("randn_pos", torch.randn(10) + 10),
("randn_neg", torch.randn(10) - 10),
("ramp", torch.arange(-16, 16, 0.2)),
]


class TestHardsigmoid(unittest.TestCase):
class Hardsigmoid(torch.nn.Module):
def __init__(self):
super().__init__()
self.hardsigmoid = torch.nn.Hardsigmoid()

def forward(self, x):
return self.hardsigmoid(x)

def _test_hardsigmoid_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
)
.export()
.check(["torch.ops.aten.hardsigmoid.default"])
.check_not(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_hardsigmoid_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple
):
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
)
.quantize()
.export()
.check(["torch.ops.aten.hardsigmoid.default"])
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.run_method_and_compare_outputs(inputs=test_data)
)

def _test_hardsigmoid_tosa_ethos_BI_pipeline(
self,
compile_spec: list[CompileSpec],
module: torch.nn.Module,
test_data: Tuple[torch.tensor],
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.export()
.check_count({"torch.ops.aten.hardsigmoid.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["executorch_exir_dialects_edge__ops_aten_clamp_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
.serialize()
)
if conftest.is_option_enabled("corstone_fvp"):
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)

@parameterized.expand(test_data_suite)
def test_hardsigmoid_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
):
self._test_hardsigmoid_tosa_MI_pipeline(self.Hardsigmoid(), (test_data,))

@parameterized.expand(test_data_suite)
def test_hardsigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardsigmoid_tosa_BI_pipeline(self.Hardsigmoid(), (test_data,))

@parameterized.expand(test_data_suite)
@pytest.mark.corstone_fvp
def test_hardsigmoid_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardsigmoid_tosa_ethos_BI_pipeline(
common.get_u55_compile_spec(), self.Hardsigmoid(), (test_data,)
)

@parameterized.expand(test_data_suite)
@pytest.mark.corstone_fvp
def test_hardsigmoid_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
self._test_hardsigmoid_tosa_ethos_BI_pipeline(
common.get_u85_compile_spec(), self.Hardsigmoid(), (test_data,)
)
Loading