diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 3e512847109..6eb1dcbef72 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -4,6 +4,15 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +"""Provide a partitioner for delegating subgraphs to the TOSA backend. + +Implement logic to identify and tag regions of an ``ExportedProgram`` that can +be delegated to the TOSA backend. Use this module to: + +- Partition graphs based on operator support and additional checks. +- Prune trivial no-op partitions that would lower to empty TOSA graphs. +- Tag constant data and report reasons for rejected nodes. +""" import logging from typing import Callable, List, Optional, Sequence, Tuple @@ -34,14 +43,46 @@ def is_noop_clone(node: torch.fx.node.Node) -> bool: + """Return True if the node is a no-op ``dim_order_ops._clone_dim_order``. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if the node targets ``dim_order_ops._clone_dim_order.default`` + in the Edge dialect; otherwise, False. + + """ return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default def is_noop_alias_copy(node: torch.fx.Node) -> bool: + """Return True if the node is a no-op ``aten.alias_copy``. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if the node targets ``aten.alias_copy.default``; otherwise, + False. + + """ return node.target == exir_ops.edge.aten.alias_copy.default def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: + """Return True if node is a no-op ``dim_order_ops._to_dim_order_copy``. + + Consider the op a no-op when the output dtype equals the input's dtype. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if it targets ``_to_dim_order_copy.default`` and preserves + dtype; otherwise, False. + + """ if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default: return False else: @@ -49,6 +90,19 @@ def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool: def is_noop_expand(node: torch.fx.node.Node) -> bool: + """Return True if the node is an ``expand_copy`` with all-ones multiples. + + This corresponds to a semantic no-op, since expanding by 1 along every + dimension leaves the tensor unchanged. + + Args: + node (torch.fx.Node): FX node to inspect. + + Returns: + bool: True if the node targets ``aten.expand_copy.default`` and all + computed multiples are 1; otherwise, False. + + """ if node.target != exir_ops.edge.aten.expand_copy.default: return False else: @@ -57,11 +111,30 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool: class TOSAPartitioner(Partitioner): + """Partition an exported program into TOSA-delegable subgraphs. + + Construct this partitioner for compile specs targeting TOSA. The partition + algorithm uses capability checks and optional additional operator-support + rules to tag nodes with a delegation tag per subgraph. + """ + def __init__( self, compile_spec: TosaCompileSpec, additional_checks: Optional[Sequence[OperatorSupportBase]] = None, ) -> None: + """Initialize the TOSAPartitioner. + + Args: + compile_spec (TosaCompileSpec): Parsed compile specifications for + TOSA containing the TOSA spec and original list. + additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra + operator-support checks to apply when partitioning. + + Raises: + RuntimeError: If the provided compile spec does not target TOSA. + + """ self.delegation_spec = DelegationSpec( TOSABackend.__name__, compile_spec.to_list() ) @@ -70,9 +143,22 @@ def __init__( self.tosa_spec = compile_spec.tosa_spec def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa - # Run the CapabilityBasedPartitioner to return the largest possible - # subgraphs containing the nodes with the tags + """Partition the program and tag TOSA-compatible subgraphs. + + Run the FX capability-based partitioner to propose subgraphs, then + refine tags by removing boundary-only quantize/dequantize nodes and by + rejecting partitions that would lower to no-ops. Emit a detailed report + of rejected nodes and their reasons. + + Args: + exported_program (ExportedProgram): Program to analyze and + partition. + + Returns: + PartitionResult: The input program with nodes tagged for delegation + and a mapping of partition tags to delegation specs. + """ logger.info("TOSAPartitioner::partition") partition_tags: dict[str, DelegationSpec] = {} @@ -92,6 +178,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no partition_list = capability_partitioner.propose_partitions() def reject_partition(reason: str, partition, tag) -> None: + """Remove a proposed partition and record the rejection reason. + + Args: + reason (str): Human-readable explanation for rejection. + partition (object): Proposed partition object from the + capability partitioner. + tag (str): Delegation tag associated with the partition. + + """ for node in partition.nodes: if "delegation_tag" in node.meta: del node.meta["delegation_tag"] @@ -105,6 +200,16 @@ def reject_partition(reason: str, partition, tag) -> None: tag = f"tag{partition.id}" def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: + """Return True if the node currently belongs to the partition ``tag``. + + Args: + node (torch.fx.Node): FX node to check. + tag (str): Delegation tag identifying the partition. + + Returns: + bool: True if the node carries the matching delegation tag. + + """ return ( "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag ) @@ -113,8 +218,8 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: node.meta["delegation_tag"] = tag partition_tags[tag] = self.delegation_spec - # De-tag outmost q-nodes upwards and dq-nodes downwards. - # De-tag if at least one input/ output is not part of partition. + # De-tag outermost q-nodes upwards and dq-nodes downwards. + # De-tag if at least one input/output is not part of the partition. for node in exported_program.graph_module.graph.nodes: if not is_partitioned(node): continue @@ -175,15 +280,41 @@ def ops_to_not_decompose( self, ep: ExportedProgram, ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """Return operators and a filter that should not be decomposed. + + Provide a base set of ops to preserve as-is and a predicate that keeps + certain activations whole when surrounded by quantize/dequantize ops in + a quantized graph. This helps downstream TOSA lowering and delegation. + + Args: + ep (ExportedProgram): Program used to infer target-specific policy. + + Returns: + Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + A list of op overloads to keep intact, and an optional filter + function that returns True when an op should not be decomposed. + + """ ops_to_not_decompose_if_quant_op = [ torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, ] def filter_fn(node: torch.fx.Node) -> bool: - # This function filters for operators to not decompose where: - # - It's target is in ops_to_not_decompose_if_quant_op list. - # - All it's inputs/outputs are quantize operators. + """Return True to keep selected ops intact inside quantized regions. + + The predicate holds when the target is in + ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are + quantize/dequantize ops, indicating a quantized activation that + should not be decomposed. + + Args: + node (torch.fx.Node): FX node to evaluate. + + Returns: + bool: True to keep the op intact; otherwise, False. + + """ dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default q = torch.ops.quantized_decomposed.quantize_per_tensor.default @@ -204,7 +335,7 @@ def filter_fn(node: torch.fx.Node) -> bool: return should_not_decompose - # Be default, do not decompose the operator + # By default, do not decompose the operator return True ops_to_not_decompose = [