3636from executorch .exir .backend .utils import tag_constant_data , WhyNoPartitionReporter
3737from executorch .exir .dialects ._ops import ops as exir_ops
3838from torch .export .exported_program import ExportedProgram
39- from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
39+ from torch .fx import GraphModule
40+ from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner , Partition
4041from torch .fx .passes .operator_support import OperatorSupportBase
4142
4243logger = logging .getLogger (__name__ )
@@ -110,6 +111,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
110111 return all (m == 1 for m in multiples )
111112
112113
114+ def is_partitioned (
115+ node : torch .fx .Node ,
116+ tag : str ,
117+ ) -> bool :
118+ """Return True if the node currently belongs to the partition ``tag``.
119+
120+ Args:
121+ node (torch.fx.Node): FX node to check.
122+ tag (str): Delegation tag identifying the partition.
123+
124+ Returns:
125+ bool: True if the node carries the matching delegation tag.
126+
127+ """
128+ return "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
129+
130+
131+ def reject_partition (
132+ reason : str , partition : Partition , reporter : WhyNoPartitionReporter
133+ ) -> None :
134+ """Remove a proposed partition and record the rejection reason.
135+
136+ Args:
137+ reason (str): Human-readable explanation for rejection.
138+ partition (object): Proposed partition object from the
139+ capability partitioner.
140+ reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
141+ """
142+ for node in partition .nodes :
143+ if "delegation_tag" in node .meta :
144+ del node .meta ["delegation_tag" ]
145+ reporter .report_reject (
146+ node ,
147+ reason ,
148+ )
149+
150+
113151class TOSAPartitioner (Partitioner ):
114152 """Partition an exported program into TOSA-delegable subgraphs.
115153
@@ -142,107 +180,64 @@ def __init__(
142180 self .additional_checks = additional_checks
143181 self .tosa_spec = compile_spec .tosa_spec
144182
145- def partition ( self , exported_program : ExportedProgram ) -> PartitionResult : # noqa
146- """Partition the program and tag TOSA-compatible subgraphs.
147-
148- Run the FX capability-based partitioner to propose subgraphs, then
149- refine tags by removing boundary-only quantize/dequantize nodes and by
150- rejecting partitions that would lower to no-ops. Emit a detailed report
151- of rejected nodes and their reasons .
183+ def _tag_module ( # noqa
184+ self ,
185+ module : GraphModule ,
186+ containing_program : ExportedProgram ,
187+ reporter : WhyNoPartitionReporter ,
188+ ) -> set [ str ]:
189+ """Tag nodes in a module, possibly a submodule, from the containing program .
152190
153191 Args:
154- exported_program (ExportedProgram): Program to analyze and
155- partition .
156-
192+ module: a GraphModule from `containing_program` to tag nodes in.
193+ containing_program: The ExportedProgram that contains the module .
194+ reporter: A reporter to report why nodes were rejected.
157195 Returns:
158- PartitionResult: The input program with nodes tagged for delegation
159- and a mapping of partition tags to delegation specs.
160-
196+ A set of strings with the partition tags.
161197 """
162- logger .info ("TOSAPartitioner::partition" )
163- partition_tags : dict [str , DelegationSpec ] = {}
164-
165- logger .info (
166- f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
167- )
168-
169- reporter = WhyNoPartitionReporter ()
198+ tags : set [str ] = set ()
170199 operator_support = tosa_support_factory (
171- self .tosa_spec , exported_program , reporter , self .additional_checks
200+ self .tosa_spec , containing_program , reporter , self .additional_checks
172201 )
173202 capability_partitioner = CapabilityBasedPartitioner (
174- exported_program . graph_module ,
203+ module ,
175204 operator_support ,
176205 allows_single_node_partition = True ,
177206 )
178207 partition_list = capability_partitioner .propose_partitions ()
179208
180- def reject_partition (reason : str , partition , tag ) -> None :
181- """Remove a proposed partition and record the rejection reason.
182-
183- Args:
184- reason (str): Human-readable explanation for rejection.
185- partition (object): Proposed partition object from the
186- capability partitioner.
187- tag (str): Delegation tag associated with the partition.
188-
189- """
190- for node in partition .nodes :
191- if "delegation_tag" in node .meta :
192- del node .meta ["delegation_tag" ]
193- reporter .report_reject (
194- node ,
195- reason ,
196- )
197- partition_tags .pop (tag , None )
198-
199209 for partition in partition_list :
200210 tag = f"tag{ partition .id } "
201-
202- def is_partitioned (node : torch .fx .Node , tag = tag ) -> bool :
203- """Return True if the node currently belongs to the partition ``tag``.
204-
205- Args:
206- node (torch.fx.Node): FX node to check.
207- tag (str): Delegation tag identifying the partition.
208-
209- Returns:
210- bool: True if the node carries the matching delegation tag.
211-
212- """
213- return (
214- "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
215- )
211+ tags .add (tag )
216212
217213 for node in partition .nodes :
218214 node .meta ["delegation_tag" ] = tag
219- partition_tags [tag ] = self .delegation_spec
220215
221216 # De-tag outermost q-nodes upwards and dq-nodes downwards.
222217 # De-tag if at least one input/output is not part of the partition.
223- for node in exported_program . graph_module .graph .nodes :
224- if not is_partitioned (node ):
218+ for node in module .graph .nodes :
219+ if not is_partitioned (node , tag ):
225220 continue
226221 if node .target in Q_OPS :
227222 for input in node .all_input_nodes :
228- if not is_partitioned (input ):
223+ if not is_partitioned (input , tag ):
229224 del node .meta ["delegation_tag" ]
230225 break
231226 continue
232227
233228 if node .target in DQ_OPS :
234229 for user in node .users :
235- if not is_partitioned (user ):
230+ if not is_partitioned (user , tag ):
236231 del node .meta ["delegation_tag" ]
237232 break
238233 continue
239234
240235 if self .tosa_spec .support_float ():
241236 continue
242237
243- if is_partitioned (node ):
238+ if is_partitioned (node , tag ):
244239 for input in node .all_input_nodes :
245- if is_partitioned (input ):
240+ if is_partitioned (input , tag ):
246241 continue
247242 if get_first_fake_tensor (input ).dtype .is_floating_point :
248243 reporter .report_reject (
@@ -265,8 +260,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
265260 reject_partition (
266261 "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition." ,
267262 partition ,
268- tag ,
263+ reporter ,
269264 )
265+ tags .remove (tag )
266+ return tags
267+
268+ def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
269+ """Partition the program and tag TOSA-compatible subgraphs.
270+
271+ Run the FX capability-based partitioner to propose subgraphs, then
272+ refine tags by removing boundary-only quantize/dequantize nodes and by
273+ rejecting partitions that would lower to no-ops. Emit a detailed report
274+ of rejected nodes and their reasons.
275+
276+ Args:
277+ exported_program (ExportedProgram): Program to analyze and
278+ partition.
279+
280+ Returns:
281+ PartitionResult: The input program with nodes tagged for delegation
282+ and a mapping of partition tags to delegation specs.
283+
284+ """
285+ logger .info ("TOSAPartitioner::partition" )
286+ logger .info (
287+ f"Partitioning for { self .delegation_spec .backend_id } : { self .tosa_spec } "
288+ )
289+
290+ reporter = WhyNoPartitionReporter ()
291+ tags = self ._tag_module (
292+ exported_program .graph_module , exported_program , reporter
293+ )
294+ partition_tags = {tag : self .delegation_spec for tag in tags }
270295
271296 tag_constant_data (exported_program )
272297 logger .info (f"The following nodes were rejected for { self .tosa_spec } :" )
0 commit comments