44# LICENSE file in the root directory of this source tree.
55
66# pyre-unsafe
7+ """Provide a partitioner for delegating subgraphs to the TOSA backend.
8+
9+ Implement logic to identify and tag regions of an ``ExportedProgram`` that can
10+ be delegated to the TOSA backend. Use this module to:
11+
12+ - Partition graphs based on operator support and additional checks.
13+ - Prune trivial no-op partitions that would lower to empty TOSA graphs.
14+ - Tag constant data and report reasons for rejected nodes.
15+ """
716
817import logging
918from typing import Callable , List , Optional , Sequence , Tuple
3443
3544
def is_noop_clone(node: torch.fx.node.Node) -> bool:
    """Check whether ``node`` calls ``dim_order_ops._clone_dim_order.default``.

    Only the node's target is examined; arguments and metadata are not
    inspected.

    Args:
        node (torch.fx.node.Node): FX node whose target is compared.

    Returns:
        bool: True when ``node.target`` equals
            ``exir_ops.edge.dim_order_ops._clone_dim_order.default``;
            otherwise, False.

    """
    target = node.target
    clone_op = exir_ops.edge.dim_order_ops._clone_dim_order.default
    return target == clone_op
3857
3958
def is_noop_alias_copy(node: torch.fx.Node) -> bool:
    """Check whether ``node`` calls ``aten.alias_copy.default``.

    Only the node's target is examined; arguments and metadata are not
    inspected.

    Args:
        node (torch.fx.Node): FX node whose target is compared.

    Returns:
        bool: True when ``node.target`` equals
            ``exir_ops.edge.aten.alias_copy.default``; otherwise, False.

    """
    target = node.target
    alias_copy_op = exir_ops.edge.aten.alias_copy.default
    return target == alias_copy_op
4271
4372
def is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
    """Check whether ``node`` is a dtype-preserving ``_to_dim_order_copy``.

    The node qualifies when it targets
    ``dim_order_ops._to_dim_order_copy.default`` and the ``"dtype"`` entry of
    its ``meta`` equals the dtype of the fake tensor obtained from its first
    argument.

    Args:
        node (torch.fx.node.Node): FX node to inspect.

    Returns:
        bool: True when the target matches and the two dtypes compare equal;
            otherwise, False. A missing ``"dtype"`` meta entry is read as
            ``None`` before the comparison.

    """
    # Guard clause: any other operator can never be a no-op dim-order copy.
    if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
        return False

    # Read the meta entry first, then resolve the input's fake tensor.
    recorded_dtype = node.meta.get("dtype")
    source_dtype = get_first_fake_tensor(node.args[0]).dtype  # type: ignore[arg-type]
    return recorded_dtype == source_dtype
4990
5091
5192def is_noop_expand (node : torch .fx .node .Node ) -> bool :
93+ """Return True if the node is an ``expand_copy`` with all-ones multiples.
94+
95+ This corresponds to a semantic no-op, since expanding by 1 along every
96+ dimension leaves the tensor unchanged.
97+
98+ Args:
99+ node (torch.fx.Node): FX node to inspect.
100+
101+ Returns:
102+ bool: True if the node targets ``aten.expand_copy.default`` and all
103+ computed multiples are 1; otherwise, False.
104+
105+ """
52106 if node .target != exir_ops .edge .aten .expand_copy .default :
53107 return False
54108 else :
@@ -57,11 +111,30 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
57111
58112
59113class TOSAPartitioner (Partitioner ):
114+ """Partition an exported program into TOSA-delegable subgraphs.
115+
116+ Construct this partitioner for compile specs targeting TOSA. The partition
117+ algorithm uses capability checks and optional additional operator-support
118+ rules to tag nodes with a delegation tag per subgraph.
119+ """
120+
60121 def __init__ (
61122 self ,
62123 compile_spec : TosaCompileSpec ,
63124 additional_checks : Optional [Sequence [OperatorSupportBase ]] = None ,
64125 ) -> None :
126+ """Initialize the TOSAPartitioner.
127+
128+ Args:
129+ compile_spec (TosaCompileSpec): Parsed compile specifications for
130+ TOSA containing the TOSA spec and original list.
131+ additional_checks (Optional[Sequence[OperatorSupportBase]]): Extra
132+ operator-support checks to apply when partitioning.
133+
134+ Raises:
135+ RuntimeError: If the provided compile spec does not target TOSA.
136+
137+ """
65138 self .delegation_spec = DelegationSpec (
66139 TOSABackend .__name__ , compile_spec .to_list ()
67140 )
@@ -70,9 +143,22 @@ def __init__(
70143 self .tosa_spec = compile_spec .tosa_spec
71144
72145 def partition (self , exported_program : ExportedProgram ) -> PartitionResult : # noqa
73- # Run the CapabilityBasedPartitioner to return the largest possible
74- # subgraphs containing the nodes with the tags
146+ """Partition the program and tag TOSA-compatible subgraphs.
147+
148+ Run the FX capability-based partitioner to propose subgraphs, then
149+ refine tags by removing boundary-only quantize/dequantize nodes and by
150+ rejecting partitions that would lower to no-ops. Emit a detailed report
151+ of rejected nodes and their reasons.
152+
153+ Args:
154+ exported_program (ExportedProgram): Program to analyze and
155+ partition.
156+
157+ Returns:
158+ PartitionResult: The input program with nodes tagged for delegation
159+ and a mapping of partition tags to delegation specs.
75160
161+ """
76162 logger .info ("TOSAPartitioner::partition" )
77163 partition_tags : dict [str , DelegationSpec ] = {}
78164
@@ -92,6 +178,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no
92178 partition_list = capability_partitioner .propose_partitions ()
93179
94180 def reject_partition (reason : str , partition , tag ) -> None :
181+ """Remove a proposed partition and record the rejection reason.
182+
183+ Args:
184+ reason (str): Human-readable explanation for rejection.
185+ partition (object): Proposed partition object from the
186+ capability partitioner.
187+ tag (str): Delegation tag associated with the partition.
188+
189+ """
95190 for node in partition .nodes :
96191 if "delegation_tag" in node .meta :
97192 del node .meta ["delegation_tag" ]
@@ -105,6 +200,16 @@ def reject_partition(reason: str, partition, tag) -> None:
105200 tag = f"tag{ partition .id } "
106201
107202 def is_partitioned (node : torch .fx .Node , tag = tag ) -> bool :
203+ """Return True if the node currently belongs to the partition ``tag``.
204+
205+ Args:
206+ node (torch.fx.Node): FX node to check.
207+ tag (str): Delegation tag identifying the partition.
208+
209+ Returns:
210+ bool: True if the node carries the matching delegation tag.
211+
212+ """
108213 return (
109214 "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
110215 )
@@ -113,8 +218,8 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
113218 node .meta ["delegation_tag" ] = tag
114219 partition_tags [tag ] = self .delegation_spec
115220
116- # De-tag outmost q-nodes upwards and dq-nodes downwards.
117- # De-tag if at least one input/ output is not part of partition.
221+ # De-tag outermost q-nodes upwards and dq-nodes downwards.
222+ # De-tag if at least one input/output is not part of the partition.
118223 for node in exported_program .graph_module .graph .nodes :
119224 if not is_partitioned (node ):
120225 continue
@@ -175,15 +280,41 @@ def ops_to_not_decompose(
175280 self ,
176281 ep : ExportedProgram ,
177282 ) -> Tuple [List [torch ._ops .OpOverload ], Optional [Callable [[torch .fx .Node ], bool ]]]:
283+ """Return operators and a filter that should not be decomposed.
284+
285+ Provide a base set of ops to preserve as-is and a predicate that keeps
286+ certain activations whole when surrounded by quantize/dequantize ops in
287+ a quantized graph. This helps downstream TOSA lowering and delegation.
288+
289+ Args:
290+ ep (ExportedProgram): Program used to infer target-specific policy.
291+
292+ Returns:
293+ Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
294+ A list of op overloads to keep intact, and an optional filter
295+ function that returns True when an op should not be decomposed.
296+
297+ """
178298 ops_to_not_decompose_if_quant_op = [
179299 torch .ops .aten .hardsigmoid .default ,
180300 torch .ops .aten .hardswish .default ,
181301 ]
182302
183303 def filter_fn (node : torch .fx .Node ) -> bool :
184- # This function filters for operators to not decompose where:
185- # - It's target is in ops_to_not_decompose_if_quant_op list.
186- # - All it's inputs/outputs are quantize operators.
304+ """Return True to keep selected ops intact inside quantized regions.
305+
306+ The predicate holds when the target is in
307+ ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
308+ quantize/dequantize ops, indicating a quantized activation that
309+ should not be decomposed.
310+
311+ Args:
312+ node (torch.fx.Node): FX node to evaluate.
313+
314+ Returns:
315+ bool: True to keep the op intact; otherwise, False.
316+
317+ """
187318 dq = torch .ops .quantized_decomposed .dequantize_per_tensor .default
188319 q = torch .ops .quantized_decomposed .quantize_per_tensor .default
189320
@@ -204,7 +335,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
204335
205336 return should_not_decompose
206337
207- # Be default, do not decompose the operator
338+ # By default, do not decompose the operator
208339 return True
209340
210341 ops_to_not_decompose = [
0 commit comments