Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions backends/arm/_passes/tag_unquantized_nodes_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm.tosa_quant_utils import dq_q_ops, get_neighbour_quant_args
from executorch.exir.pass_base import ExportPass, PassResult


class TagUnquantizedNodesPass(ExportPass):
"""
Pass run before partitioning to tag unquantized nodes
to ensure we don't greedily partition them for device. Unquantized operations must remain on the CPU.
"""

def is_node_quantized(self, node: torch.fx.Node) -> bool:
user_q_args, input_q_args = get_neighbour_quant_args(node)

# If there are no neighboring quantized nodes, then this node is not quantized except for constants,
# they can only have a dequantization node.
if (
len(node.all_input_nodes) > 0
and len(input_q_args) == 0
or len(user_q_args) == 0
):
return False

return True

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
# Look through operations that are not quantization or dequantization
if node.op == "call_function" and node.target not in dq_q_ops:
is_node_quantized = self.is_node_quantized(node)
if not is_node_quantized:
# For a non-quantized node, we tag the node and its inputs and outputs.
node.meta["arm_override_partition"] = False
for input_node in node.all_input_nodes:
input_node.meta["arm_override_partition"] = False
for user in node.users.keys():
user.meta["arm_override_partition"] = False

graph_module.recompile()
return PassResult(graph_module, True)
16 changes: 16 additions & 0 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self):
# TODO MLETORCH-265 Remove permute_nhwc flag
self.permute_nhwc = False
self.quantize_io = False
self.unquantized_nodes_to_cpu = False
self.tosa_version = None
self.input_order = None

Expand Down Expand Up @@ -146,6 +147,16 @@ def set_input_order(
self.input_order = input_order
return self

def set_unquantized_nodes_to_cpu(
self, unquantized_nodes_to_cpu: bool = False
) -> "ArmCompileSpecBuilder":
"""
For models with operations that are not quantized,
this option keeps the unquantized operators on the CPU.
"""
self.unquantized_nodes_to_cpu = unquantized_nodes_to_cpu
return self

def build(self) -> List[CompileSpec]:
"""
Generate a list of compile spec objects from the builder
Expand Down Expand Up @@ -185,6 +196,11 @@ def build(self) -> List[CompileSpec]:
if self.quantize_io:
self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))

if self.unquantized_nodes_to_cpu:
self.compile_spec.append(
CompileSpec("unquantized_nodes_to_cpu", "True".encode())
)

return self.compile_spec


Expand Down
17 changes: 11 additions & 6 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import torch
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm._passes.tag_unquantized_nodes_pass import (
TagUnquantizedNodesPass,
)
from executorch.backends.arm.operator_support.tosa_supported_operators import (
TOSASupportedOperators,
)
Expand Down Expand Up @@ -52,15 +55,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

logger.info(f"Partitioning for {tosa_spec}")

passes = []
for spec in self.delegation_spec.compile_specs:
if spec.key == "quantize_io" and spec.value.decode() == "True":
# Exclude IO quantization from the partition
passes = PassManager(
passes=[
TagIOQuantPass(),
]
)
passes(exported_program.graph_module)
passes.append(TagIOQuantPass())
if spec.key == "unquantized_nodes_to_cpu" and spec.value.decode() == "True":
# Exclude unquantized nodes from the partition
passes.append(TagUnquantizedNodesPass())

passes = PassManager(passes=passes)
passes(exported_program.graph_module)

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_u55_compile_spec(
quantize_io=False,
custom_path=None,
reorder_inputs=None,
unquantized_nodes_to_cpu=False,
) -> list[CompileSpec]:
"""
Default compile spec for Ethos-U55 tests.
Expand All @@ -102,6 +103,7 @@ def get_u55_compile_spec(
quantize_io=quantize_io,
custom_path=custom_path,
reorder_inputs=reorder_inputs,
unquantized_nodes_to_cpu=unquantized_nodes_to_cpu,
).build()


Expand All @@ -110,6 +112,7 @@ def get_u85_compile_spec(
quantize_io=False,
custom_path=None,
reorder_inputs=None,
unquantized_nodes_to_cpu=False,
) -> list[CompileSpec]:
"""
Default compile spec for Ethos-U85 tests.
Expand All @@ -119,6 +122,7 @@ def get_u85_compile_spec(
quantize_io=quantize_io,
custom_path=custom_path,
reorder_inputs=reorder_inputs,
unquantized_nodes_to_cpu=unquantized_nodes_to_cpu,
).build()


Expand All @@ -127,6 +131,7 @@ def get_u55_compile_spec_unbuilt(
quantize_io=False,
custom_path=None,
reorder_inputs=None,
unquantized_nodes_to_cpu=False,
) -> ArmCompileSpecBuilder:
"""Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify
the compile spec before calling .build() to finalize it.
Expand All @@ -143,6 +148,7 @@ def get_u55_compile_spec_unbuilt(
extra_flags="--debug-force-regor --output-format=raw",
)
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
.set_unquantized_nodes_to_cpu(unquantized_nodes_to_cpu)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(artifact_path)
.set_input_order(reorder_inputs)
Expand All @@ -155,6 +161,7 @@ def get_u85_compile_spec_unbuilt(
quantize_io=False,
custom_path=None,
reorder_inputs=None,
unquantized_nodes_to_cpu=False,
) -> list[CompileSpec]:
"""Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
the compile spec before calling .build() to finalize it.
Expand All @@ -169,6 +176,7 @@ def get_u85_compile_spec_unbuilt(
extra_flags="--output-format=raw",
)
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
.set_unquantized_nodes_to_cpu(unquantized_nodes_to_cpu)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(artifact_path)
.set_input_order(reorder_inputs)
Expand Down
103 changes: 103 additions & 0 deletions backends/arm/test/passes/test_tag_unquantized_nodes_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.compile_spec_schema import CompileSpec


class TestModel(torch.nn.Module):

def get_inputs(self):
return (torch.rand(1, 10, 10, 10), (torch.rand(1, 10, 10, 10)))

def forward(self, x, y):
result = x + y
result = result * y
result = result * x
result = result - y
return result


class TestTagUnquantizedNodesPass(unittest.TestCase):
"""
Tests the TagUnquantizedNodesPass which tags unquantized nodes on model
to not include them in our partitions.
"""

def _tosa_BI_pipeline(
self, module: torch.nn.Module, compile_spec: list[CompileSpec]
):
quantizer = ArmQuantizer()
# Quantize only add and sub nodes
quantizer.STATIC_ANNOTATION_ORDER = [
"add",
"sub",
]
(
ArmTester(
module,
example_inputs=module.get_inputs(),
compile_spec=compile_spec,
)
.quantize(
Quantize(
quantizer,
get_symmetric_quantization_config(is_per_channel=False),
)
)
.export()
.to_edge()
.check_count(
{
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 5
}
)
.check_count(
{
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 6
}
)
.partition()
.check_count(
{
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3
}
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
.check_count(
{
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2
}
)
)

def test_BI_u55_artifact(self):
model = TestModel()
self._tosa_BI_pipeline(
model,
common.get_u55_compile_spec(
quantize_io=True, unquantized_nodes_to_cpu=True
),
)

def test_BI_u85_artifact(self):
model = TestModel()
self._tosa_BI_pipeline(
model,
common.get_u85_compile_spec(
quantize_io=True, unquantized_nodes_to_cpu=True
),
)
Loading