Skip to content

Commit 5c25493

Browse files
Arm backend: Add docstrings for tosa/partitioner.py (#14844)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent b88b09c commit 5c25493

File tree

1 file changed

+139
-8
lines changed

1 file changed

+139
-8
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
"""Provide a partitioner for delegating subgraphs to the TOSA backend.
8+
9+
Implement logic to identify and tag regions of an ``ExportedProgram`` that can
10+
be delegated to the TOSA backend. Use this module to:
11+
12+
- Partition graphs based on operator support and additional checks.
13+
- Prune trivial no-op partitions that would lower to empty TOSA graphs.
14+
- Tag constant data and report reasons for rejected nodes.
15+
"""
716

817
import logging
918
from typing import Callable, List, Optional, Sequence, Tuple
@@ -34,21 +43,66 @@
3443

3544

3645
def is_noop_clone(node: torch.fx.node.Node) -> bool:
46+
"""Return True if the node is a no-op ``dim_order_ops._clone_dim_order``.
47+
48+
Args:
49+
node (torch.fx.Node): FX node to inspect.
50+
51+
Returns:
52+
bool: True if the node targets ``dim_order_ops._clone_dim_order.default``
53+
in the Edge dialect; otherwise, False.
54+
55+
"""
3756
return node.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
3857

3958

4059
def is_noop_alias_copy(node: torch.fx.Node) -> bool:
60+
"""Return True if the node is a no-op ``aten.alias_copy``.
61+
62+
Args:
63+
node (torch.fx.Node): FX node to inspect.
64+
65+
Returns:
66+
bool: True if the node targets ``aten.alias_copy.default``; otherwise,
67+
False.
68+
69+
"""
4170
return node.target == exir_ops.edge.aten.alias_copy.default
4271

4372

4473
def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
74+
"""Return True if node is a no-op ``dim_order_ops._to_dim_order_copy``.
75+
76+
Consider the op a no-op when the output dtype equals the input's dtype.
77+
78+
Args:
79+
node (torch.fx.Node): FX node to inspect.
80+
81+
Returns:
82+
bool: True if it targets ``_to_dim_order_copy.default`` and preserves
83+
dtype; otherwise, False.
84+
85+
"""
4586
if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
4687
return False
4788
else:
4889
return node.meta.get("dtype") == get_first_fake_tensor(node.args[0]).dtype # type: ignore[arg-type]
4990

5091

5192
def is_noop_expand(node: torch.fx.node.Node) -> bool:
93+
"""Return True if the node is an ``expand_copy`` with all-ones multiples.
94+
95+
This corresponds to a semantic no-op, since expanding by 1 along every
96+
dimension leaves the tensor unchanged.
97+
98+
Args:
99+
node (torch.fx.Node): FX node to inspect.
100+
101+
Returns:
102+
bool: True if the node targets ``aten.expand_copy.default`` and all
103+
computed multiples are 1; otherwise, False.
104+
105+
"""
52106
if node.target != exir_ops.edge.aten.expand_copy.default:
53107
return False
54108
else:
@@ -57,11 +111,30 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
57111

58112

59113
class TOSAPartitioner(Partitioner):
114+
"""Partition an exported program into TOSA-delegable subgraphs.
115+
116+
Construct this partitioner for compile specs targeting TOSA. The partition
117+
algorithm uses capability checks and optional additional operator-support
118+
rules to tag nodes with a delegation tag per subgraph.
119+
"""
120+
60121
def __init__(
61122
self,
62123
compile_spec: TosaCompileSpec,
63124
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
64125
) -> None:
126+
"""Initialize the TOSAPartitioner.
127+
128+
Args:
129+
compile_spec (TosaCompileSpec): Parsed compile specifications for
130+
TOSA containing the TOSA spec and original list.
131+
additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra
132+
operator-support checks to apply when partitioning.
133+
134+
Raises:
135+
RuntimeError: If the provided compile spec does not target TOSA.
136+
137+
"""
65138
self.delegation_spec = DelegationSpec(
66139
TOSABackend.__name__, compile_spec.to_list()
67140
)
@@ -70,9 +143,22 @@ def __init__(
70143
self.tosa_spec = compile_spec.tosa_spec
71144

72145
def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
73-
# Run the CapabilityBasedPartitioner to return the largest possible
74-
# subgraphs containing the nodes with the tags
146+
"""Partition the program and tag TOSA-compatible subgraphs.
147+
148+
Run the FX capability-based partitioner to propose subgraphs, then
149+
refine tags by removing boundary-only quantize/dequantize nodes and by
150+
rejecting partitions that would lower to no-ops. Emit a detailed report
151+
of rejected nodes and their reasons.
152+
153+
Args:
154+
exported_program (ExportedProgram): Program to analyze and
155+
partition.
156+
157+
Returns:
158+
PartitionResult: The input program with nodes tagged for delegation
159+
and a mapping of partition tags to delegation specs.
75160
161+
"""
76162
logger.info("TOSAPartitioner::partition")
77163
partition_tags: dict[str, DelegationSpec] = {}
78164

@@ -92,6 +178,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no
92178
partition_list = capability_partitioner.propose_partitions()
93179

94180
def reject_partition(reason: str, partition, tag) -> None:
181+
"""Remove a proposed partition and record the rejection reason.
182+
183+
Args:
184+
reason (str): Human-readable explanation for rejection.
185+
partition (object): Proposed partition object from the
186+
capability partitioner.
187+
tag (str): Delegation tag associated with the partition.
188+
189+
"""
95190
for node in partition.nodes:
96191
if "delegation_tag" in node.meta:
97192
del node.meta["delegation_tag"]
@@ -105,6 +200,16 @@ def reject_partition(reason: str, partition, tag) -> None:
105200
tag = f"tag{partition.id}"
106201

107202
def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
203+
"""Return True if the node currently belongs to the partition ``tag``.
204+
205+
Args:
206+
node (torch.fx.Node): FX node to check.
207+
tag (str): Delegation tag identifying the partition.
208+
209+
Returns:
210+
bool: True if the node carries the matching delegation tag.
211+
212+
"""
108213
return (
109214
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
110215
)
@@ -113,8 +218,8 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
113218
node.meta["delegation_tag"] = tag
114219
partition_tags[tag] = self.delegation_spec
115220

116-
# De-tag outmost q-nodes upwards and dq-nodes downwards.
117-
# De-tag if at least one input/ output is not part of partition.
221+
# De-tag outermost q-nodes upwards and dq-nodes downwards.
222+
# De-tag if at least one input/output is not part of the partition.
118223
for node in exported_program.graph_module.graph.nodes:
119224
if not is_partitioned(node):
120225
continue
@@ -175,15 +280,41 @@ def ops_to_not_decompose(
175280
self,
176281
ep: ExportedProgram,
177282
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
283+
"""Return operators and a filter that should not be decomposed.
284+
285+
Provide a base set of ops to preserve as-is and a predicate that keeps
286+
certain activations whole when surrounded by quantize/dequantize ops in
287+
a quantized graph. This helps downstream TOSA lowering and delegation.
288+
289+
Args:
290+
ep (ExportedProgram): Program used to infer target-specific policy.
291+
292+
Returns:
293+
Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
294+
A list of op overloads to keep intact, and an optional filter
295+
function that returns True when an op should not be decomposed.
296+
297+
"""
178298
ops_to_not_decompose_if_quant_op = [
179299
torch.ops.aten.hardsigmoid.default,
180300
torch.ops.aten.hardswish.default,
181301
]
182302

183303
def filter_fn(node: torch.fx.Node) -> bool:
184-
# This function filters for operators to not decompose where:
185-
# - It's target is in ops_to_not_decompose_if_quant_op list.
186-
# - All it's inputs/outputs are quantize operators.
304+
"""Return True to keep selected ops intact inside quantized regions.
305+
306+
The predicate holds when the target is in
307+
``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
308+
quantize/dequantize ops, indicating a quantized activation that
309+
should not be decomposed.
310+
311+
Args:
312+
node (torch.fx.Node): FX node to evaluate.
313+
314+
Returns:
315+
bool: True to keep the op intact; otherwise, False.
316+
317+
"""
187318
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default
188319
q = torch.ops.quantized_decomposed.quantize_per_tensor.default
189320

@@ -204,7 +335,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
204335

205336
return should_not_decompose
206337

207-
# Be default, do not decompose the operator
338+
# By default, do not decompose the operator
208339
return True
209340

210341
ops_to_not_decompose = [

0 commit comments

Comments
 (0)