Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 139 additions & 8 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
"""Provide a partitioner for delegating subgraphs to the TOSA backend.

Implement logic to identify and tag regions of an ``ExportedProgram`` that can
be delegated to the TOSA backend. Use this module to:

- Partition graphs based on operator support and additional checks.
- Prune trivial no-op partitions that would lower to empty TOSA graphs.
- Tag constant data and report reasons for rejected nodes.
"""

import logging
from typing import Callable, List, Optional, Sequence, Tuple
Expand Down Expand Up @@ -34,21 +43,66 @@


def is_noop_clone(node: torch.fx.node.Node) -> bool:
    """Tell whether ``node`` is a ``dim_order_ops._clone_dim_order`` call.

    Only the node's ``target`` is inspected; its arguments and metadata are
    not consulted.

    Args:
        node (torch.fx.Node): FX node to inspect.

    Returns:
        bool: True when ``node.target`` equals the Edge-dialect overload
            ``dim_order_ops._clone_dim_order.default``; otherwise, False.

    """
    observed_target = node.target
    return observed_target == exir_ops.edge.dim_order_ops._clone_dim_order.default


def is_noop_alias_copy(node: torch.fx.Node) -> bool:
    """Tell whether ``node`` is an ``aten.alias_copy`` call.

    Only the node's ``target`` is inspected; its arguments and metadata are
    not consulted.

    Args:
        node (torch.fx.Node): FX node to inspect.

    Returns:
        bool: True when ``node.target`` equals the Edge-dialect overload
            ``aten.alias_copy.default``; otherwise, False.

    """
    observed_target = node.target
    return observed_target == exir_ops.edge.aten.alias_copy.default


def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
    """Tell whether ``node`` is a dtype-preserving ``_to_dim_order_copy``.

    The op is treated as a no-op when the dtype recorded under
    ``node.meta["dtype"]`` matches the dtype of the fake tensor obtained from
    the node's first argument.

    Args:
        node (torch.fx.Node): FX node to inspect.

    Returns:
        bool: True when the node targets
            ``dim_order_ops._to_dim_order_copy.default`` and the two dtypes
            compare equal; otherwise, False.

    """
    # Guard clause: any other target can never qualify.
    if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
        return False

    # A missing "dtype" entry yields None, which is compared as-is below.
    requested_dtype = node.meta.get("dtype")
    source_dtype = get_first_fake_tensor(node.args[0]).dtype  # type: ignore[arg-type]
    return requested_dtype == source_dtype


def is_noop_expand(node: torch.fx.node.Node) -> bool:
"""Return True if the node is an ``expand_copy`` with all-ones multiples.

This corresponds to a semantic no-op, since expanding by 1 along every
dimension leaves the tensor unchanged.

Args:
node (torch.fx.Node): FX node to inspect.

Returns:
bool: True if the node targets ``aten.expand_copy.default`` and all
computed multiples are 1; otherwise, False.

"""
if node.target != exir_ops.edge.aten.expand_copy.default:
return False
else:
Expand All @@ -57,11 +111,30 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:


class TOSAPartitioner(Partitioner):
"""Partition an exported program into TOSA-delegable subgraphs.

Construct this partitioner for compile specs targeting TOSA. The partition
algorithm uses capability checks and optional additional operator-support
rules to tag nodes with a delegation tag per subgraph.
"""

def __init__(
self,
compile_spec: TosaCompileSpec,
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
) -> None:
"""Initialize the TOSAPartitioner.

Args:
compile_spec (TosaCompileSpec): Parsed compile specifications for
TOSA containing the TOSA spec and original list.
additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra
operator-support checks to apply when partitioning.

Raises:
RuntimeError: If the provided compile spec does not target TOSA.

"""
self.delegation_spec = DelegationSpec(
TOSABackend.__name__, compile_spec.to_list()
)
Expand All @@ -70,9 +143,22 @@ def __init__(
self.tosa_spec = compile_spec.tosa_spec

def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
# Run the CapabilityBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags
"""Partition the program and tag TOSA-compatible subgraphs.

Run the FX capability-based partitioner to propose subgraphs, then
refine tags by removing boundary-only quantize/dequantize nodes and by
rejecting partitions that would lower to no-ops. Emit a detailed report
of rejected nodes and their reasons.

Args:
exported_program (ExportedProgram): Program to analyze and
partition.

Returns:
PartitionResult: The input program with nodes tagged for delegation
and a mapping of partition tags to delegation specs.

"""
logger.info("TOSAPartitioner::partition")
partition_tags: dict[str, DelegationSpec] = {}

Expand All @@ -92,6 +178,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no
partition_list = capability_partitioner.propose_partitions()

def reject_partition(reason: str, partition, tag) -> None:
"""Remove a proposed partition and record the rejection reason.

Args:
reason (str): Human-readable explanation for rejection.
partition (object): Proposed partition object from the
capability partitioner.
tag (str): Delegation tag associated with the partition.

"""
for node in partition.nodes:
if "delegation_tag" in node.meta:
del node.meta["delegation_tag"]
Expand All @@ -105,6 +200,16 @@ def reject_partition(reason: str, partition, tag) -> None:
tag = f"tag{partition.id}"

def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
    """Check whether ``node`` is currently tagged with ``tag``.

    Args:
        node (torch.fx.Node): FX node whose ``meta`` is examined.
        tag (str): Delegation tag to compare against. The default is the
            value of ``tag`` in the enclosing scope at definition time.

    Returns:
        bool: True if ``node.meta`` has a ``"delegation_tag"`` entry equal
            to ``tag``; otherwise, False.

    """
    # Untagged nodes belong to no partition at all.
    if "delegation_tag" not in node.meta:
        return False
    return node.meta["delegation_tag"] == tag
Expand All @@ -113,8 +218,8 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

# De-tag outmost q-nodes upwards and dq-nodes downwards.
# De-tag if at least one input/ output is not part of partition.
# De-tag outermost q-nodes upwards and dq-nodes downwards.
# De-tag if at least one input/output is not part of the partition.
for node in exported_program.graph_module.graph.nodes:
if not is_partitioned(node):
continue
Expand Down Expand Up @@ -175,15 +280,41 @@ def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""Return operators and a filter that should not be decomposed.

Provide a base set of ops to preserve as-is and a predicate that keeps
certain activations whole when surrounded by quantize/dequantize ops in
a quantized graph. This helps downstream TOSA lowering and delegation.

Args:
ep (ExportedProgram): Program used to infer target-specific policy.

Returns:
Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
A list of op overloads to keep intact, and an optional filter
function that returns True when an op should not be decomposed.

"""
ops_to_not_decompose_if_quant_op = [
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
]

def filter_fn(node: torch.fx.Node) -> bool:
# This function filters for operators to not decompose where:
# - Its target is in the ops_to_not_decompose_if_quant_op list.
# - All its inputs/outputs are quantize operators.
"""Return True to keep selected ops intact inside quantized regions.

The predicate holds when the target is in
``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
quantize/dequantize ops, indicating a quantized activation that
should not be decomposed.

Args:
node (torch.fx.Node): FX node to evaluate.

Returns:
bool: True to keep the op intact; otherwise, False.

"""
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
q = torch.ops.quantized_decomposed.quantize_per_tensor.default

Expand All @@ -204,7 +335,7 @@ def filter_fn(node: torch.fx.Node) -> bool:

return should_not_decompose

# Be default, do not decompose the operator
# By default, do not decompose the operator
return True

ops_to_not_decompose = [
Expand Down
Loading