Skip to content

Commit bb6dbc3

Browse files
committed
Arm backend: break out tagging in partitioner
This prepares the partitioner for partitioning submodules. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Ie3123c672ad4df93b8a9f835c3908d58668e27d2
1 parent 8f0e8a8 commit bb6dbc3

File tree

1 file changed

+91
-66
lines changed

1 file changed

+91
-66
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 91 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter
3737
from executorch.exir.dialects._ops import ops as exir_ops
3838
from torch.export.exported_program import ExportedProgram
39-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
39+
from torch.fx import GraphModule
40+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
4041
from torch.fx.passes.operator_support import OperatorSupportBase
4142

4243
logger = logging.getLogger(__name__)
@@ -110,6 +111,43 @@ def is_noop_expand(node: torch.fx.node.Node) -> bool:
110111
return all(m == 1 for m in multiples)
111112

112113

114+
def is_partitioned(
115+
node: torch.fx.Node,
116+
tag: str,
117+
) -> bool:
118+
"""Return True if the node currently belongs to the partition ``tag``.
119+
120+
Args:
121+
node (torch.fx.Node): FX node to check.
122+
tag (str): Delegation tag identifying the partition.
123+
124+
Returns:
125+
bool: True if the node carries the matching delegation tag.
126+
127+
"""
128+
return "delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
129+
130+
131+
def reject_partition(
132+
reason: str, partition: Partition, reporter: WhyNoPartitionReporter
133+
) -> None:
134+
"""Remove a proposed partition and record the rejection reason.
135+
136+
Args:
137+
reason (str): Human-readable explanation for rejection.
138+
partition (object): Proposed partition object from the
139+
capability partitioner.
140+
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.
141+
"""
142+
for node in partition.nodes:
143+
if "delegation_tag" in node.meta:
144+
del node.meta["delegation_tag"]
145+
reporter.report_reject(
146+
node,
147+
reason,
148+
)
149+
150+
113151
class TOSAPartitioner(Partitioner):
114152
"""Partition an exported program into TOSA-delegable subgraphs.
115153
@@ -142,107 +180,64 @@ def __init__(
142180
self.additional_checks = additional_checks
143181
self.tosa_spec = compile_spec.tosa_spec
144182

145-
def partition(self, exported_program: ExportedProgram) -> PartitionResult: # noqa
146-
"""Partition the program and tag TOSA-compatible subgraphs.
147-
148-
Run the FX capability-based partitioner to propose subgraphs, then
149-
refine tags by removing boundary-only quantize/dequantize nodes and by
150-
rejecting partitions that would lower to no-ops. Emit a detailed report
151-
of rejected nodes and their reasons.
183+
def _tag_module( # noqa
184+
self,
185+
module: GraphModule,
186+
containing_program: ExportedProgram,
187+
reporter: WhyNoPartitionReporter,
188+
) -> set[str]:
189+
"""Tag nodes in a module, possibly a submodule, from the containing program.
152190
153191
Args:
154-
exported_program (ExportedProgram): Program to analyze and
155-
partition.
156-
192+
module: a GraphModule from `containing_program` to tag nodes in.
193+
containing_program: The ExportedProgram that contains the module.
194+
reporter: A reporter to report why nodes were rejected.
157195
Returns:
158-
PartitionResult: The input program with nodes tagged for delegation
159-
and a mapping of partition tags to delegation specs.
160-
196+
A set of strings with the partition tags.
161197
"""
162-
logger.info("TOSAPartitioner::partition")
163-
partition_tags: dict[str, DelegationSpec] = {}
164-
165-
logger.info(
166-
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
167-
)
168-
169-
reporter = WhyNoPartitionReporter()
198+
tags: set[str] = set()
170199
operator_support = tosa_support_factory(
171-
self.tosa_spec, exported_program, reporter, self.additional_checks
200+
self.tosa_spec, containing_program, reporter, self.additional_checks
172201
)
173202
capability_partitioner = CapabilityBasedPartitioner(
174-
exported_program.graph_module,
203+
module,
175204
operator_support,
176205
allows_single_node_partition=True,
177206
)
178207
partition_list = capability_partitioner.propose_partitions()
179208

180-
def reject_partition(reason: str, partition, tag) -> None:
181-
"""Remove a proposed partition and record the rejection reason.
182-
183-
Args:
184-
reason (str): Human-readable explanation for rejection.
185-
partition (object): Proposed partition object from the
186-
capability partitioner.
187-
tag (str): Delegation tag associated with the partition.
188-
189-
"""
190-
for node in partition.nodes:
191-
if "delegation_tag" in node.meta:
192-
del node.meta["delegation_tag"]
193-
reporter.report_reject(
194-
node,
195-
reason,
196-
)
197-
partition_tags.pop(tag, None)
198-
199209
for partition in partition_list:
200210
tag = f"tag{partition.id}"
201-
202-
def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
203-
"""Return True if the node currently belongs to the partition ``tag``.
204-
205-
Args:
206-
node (torch.fx.Node): FX node to check.
207-
tag (str): Delegation tag identifying the partition.
208-
209-
Returns:
210-
bool: True if the node carries the matching delegation tag.
211-
212-
"""
213-
return (
214-
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
215-
)
211+
tags.add(tag)
216212

217213
for node in partition.nodes:
218214
node.meta["delegation_tag"] = tag
219-
partition_tags[tag] = self.delegation_spec
220215

221216
# De-tag outermost q-nodes upwards and dq-nodes downwards.
222217
# De-tag if at least one input/output is not part of the partition.
223-
for node in exported_program.graph_module.graph.nodes:
224-
if not is_partitioned(node):
218+
for node in module.graph.nodes:
219+
if not is_partitioned(node, tag):
225220
continue
226221
if node.target in Q_OPS:
227222
for input in node.all_input_nodes:
228-
if not is_partitioned(input):
223+
if not is_partitioned(input, tag):
229224
del node.meta["delegation_tag"]
230225
break
231226
continue
232227

233228
if node.target in DQ_OPS:
234229
for user in node.users:
235-
if not is_partitioned(user):
230+
if not is_partitioned(user, tag):
236231
del node.meta["delegation_tag"]
237232
break
238233
continue
239234

240235
if self.tosa_spec.support_float():
241236
continue
242237

243-
if is_partitioned(node):
238+
if is_partitioned(node, tag):
244239
for input in node.all_input_nodes:
245-
if is_partitioned(input):
240+
if is_partitioned(input, tag):
246241
continue
247242
if get_first_fake_tensor(input).dtype.is_floating_point:
248243
reporter.report_reject(
@@ -265,8 +260,38 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
265260
reject_partition(
266261
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
267262
partition,
268-
tag,
263+
reporter,
269264
)
265+
tags.remove(tag)
266+
return tags
267+
268+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
269+
"""Partition the program and tag TOSA-compatible subgraphs.
270+
271+
Run the FX capability-based partitioner to propose subgraphs, then
272+
refine tags by removing boundary-only quantize/dequantize nodes and by
273+
rejecting partitions that would lower to no-ops. Emit a detailed report
274+
of rejected nodes and their reasons.
275+
276+
Args:
277+
exported_program (ExportedProgram): Program to analyze and
278+
partition.
279+
280+
Returns:
281+
PartitionResult: The input program with nodes tagged for delegation
282+
and a mapping of partition tags to delegation specs.
283+
284+
"""
285+
logger.info("TOSAPartitioner::partition")
286+
logger.info(
287+
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
288+
)
289+
290+
reporter = WhyNoPartitionReporter()
291+
tags = self._tag_module(
292+
exported_program.graph_module, exported_program, reporter
293+
)
294+
partition_tags = {tag: self.delegation_spec for tag in tags}
270295

271296
tag_constant_data(exported_program)
272297
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")

0 commit comments

Comments
 (0)