Skip to content

Commit 1c77077

Browse files
committed
NXP backend: Add infrastructure for context dependant partitioning.
1 parent d43cde5 commit 1c77077

File tree

2 files changed

+68
-3
lines changed

2 files changed

+68
-3
lines changed

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from torch.fx import Node
21+
from torch.fx.passes.infra.partitioner import Partition
2122
from torch.nn import Parameter
2223

2324

@@ -125,6 +126,23 @@ def is_supported(
125126
node, target, parameters_mapping, custom_delegation_options
126127
)
127128

129+
@classmethod
130+
def supports_partitioning_result(
131+
cls,
132+
node: Node,
133+
partition_list: list[Partition],
134+
custom_delegation_options: CustomDelegationOptions,
135+
):
136+
"""Check if the given `node` supports the assigned partitioning, which is stored the `partition_list`. Child
137+
classes can overwrite this method in case they have delegation restrictions based on the context defined by
138+
the partitioning result.
139+
140+
:param node: torch.Node to check.
141+
:param partition_list: List of proposed partitions.
142+
:param custom_delegation_options: Custom user options which affect node delegation.
143+
"""
144+
return True
145+
128146
@staticmethod
129147
def _has_shared_q_params_if_quantized(node: Node) -> bool:
130148
"""Check if node has shared quantization parameters if it's quantized."""

backends/nxp/neutron_partitioner.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
)
2121
from executorch.backends.nxp.backend.ir.converter.node_converter import Target
2222
from torch.export.exported_program import ExportedProgram
23-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
23+
from torch.fx import Graph
24+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
2425
from torch.fx.passes.operator_support import OperatorSupportBase
2526
from torch.nn import Parameter
2627
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
@@ -34,6 +35,9 @@
3435
from executorch.exir.backend.utils import tag_constant_data
3536
from executorch.exir.dialects._ops import ops as exir_ops
3637

38+
NXP_DO_NOT_DELEGATE = "NXP_DO_NOT_DELEGATE"
39+
NXP_DELEGATION_TAG = "delegation_tag"
40+
3741

3842
class QDQClusterRecognizer:
3943
"""
@@ -246,6 +250,11 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool:
246250
"""
247251
Operator checking function for compute nodes.
248252
"""
253+
254+
if hasattr(node, "meta") and node.meta.get(NXP_DO_NOT_DELEGATE, False):
255+
# The delegation of this node has been prohibited.
256+
return False
257+
249258
if not self.is_node_delegatable(node):
250259
return False
251260

@@ -304,6 +313,31 @@ def __init__(
304313
custom_delegation_options or CustomDelegationOptions()
305314
)
306315

316+
def validate_partitioning_result(
317+
self,
318+
graph: Graph,
319+
partition_list: list[Partition],
320+
custom_delegation_options: CustomDelegationOptions,
321+
) -> bool:
322+
all_delegated_nodes = {
323+
node for partition in partition_list for node in partition.nodes
324+
}
325+
partitioning_valid = True
326+
for node in graph.nodes:
327+
if (
328+
node in all_delegated_nodes
329+
and hasattr(node, "target")
330+
and node.target in supported_ops
331+
):
332+
if not supported_ops[node.target].supports_partitioning_result(
333+
node, partition_list, custom_delegation_options
334+
):
335+
# This node is not supported within its partition. Exclude it from delegation in the future.
336+
partitioning_valid = False
337+
node.meta[NXP_DO_NOT_DELEGATE] = True
338+
339+
return partitioning_valid
340+
307341
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
308342
# Run the CapabilityBasedPartitioner to return the largest possible
309343
# subgraphs containing the nodes with the tags
@@ -342,11 +376,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
342376
allows_single_node_partition=True,
343377
)
344378

345-
partition_list = capability_partitioner.propose_partitions()
379+
iteration_limit = len(exported_program.graph.nodes)
380+
for _ in range(iteration_limit):
381+
# Run the partitioning.
382+
partition_list = capability_partitioner.propose_partitions()
383+
384+
# Check if the nodes support the partitioning result. Mark the problematic nodes with `NXP_DO_NOT_DELEGATE`.
385+
partitioning_valid = self.validate_partitioning_result(
386+
exported_program.graph, partition_list, self.custom_delegation_options
387+
)
388+
if partitioning_valid:
389+
# The result of the partitioning is fine
390+
break
391+
392+
# Mark the partitions in the node `meta` attribute.
346393
for partition in partition_list:
347394
for node in partition.nodes:
348395
delegation_tag = f"tag{partition.id}"
349-
node.meta["delegation_tag"] = delegation_tag
396+
node.meta[NXP_DELEGATION_TAG] = delegation_tag
350397
partition_tags[delegation_tag] = self.delegation_spec
351398

352399
tag_constant_data(exported_program)

0 commit comments

Comments
 (0)