Skip to content

Commit 6a238e3

Browse files
NXP backend: Add infrastructure for context dependant partitioning (#14373)
### Summary This PR adds the option to specify delegation conditions which depend on the partition a particular node ends up in. This infrastructure is applied to the `view_copy` node. ### Test plan Unit tests provided. cc @robert-kalmar @roman-janik-nxp @StrycekSimon @jirioc
1 parent b100c95 commit 6a238e3

File tree

4 files changed

+166
-3
lines changed

4 files changed

+166
-3
lines changed

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from torch.fx import Node
21+
from torch.fx.passes.infra.partitioner import Partition
2122
from torch.nn import Parameter
2223

2324

@@ -37,6 +38,10 @@ def _is_dequant_node(node: torch.fx.Node) -> bool:
3738
]
3839

3940

41+
def is_not_qdq_node(node: torch.fx.Node) -> bool:
42+
return not (_is_quant_node(node) or _is_dequant_node(node))
43+
44+
4045
class Target(Enum):
4146
IGNORE = "ignore" # No target platform. Any target specific restrictions will be ignored.
4247

@@ -125,6 +130,23 @@ def is_supported(
125130
node, target, parameters_mapping, custom_delegation_options
126131
)
127132

133+
@classmethod
134+
def supports_partitioning_result(
135+
cls,
136+
node: Node,
137+
partition_list: list[Partition],
138+
custom_delegation_options: CustomDelegationOptions,
139+
):
140+
"""Check if the given `node` supports the assigned partitioning, which is stored the `partition_list`. Child
141+
classes can overwrite this method in case they have delegation restrictions based on the context defined by
142+
the partitioning result.
143+
144+
:param node: torch.Node to check.
145+
:param partition_list: List of proposed partitions.
146+
:param custom_delegation_options: Custom user options which affect node delegation.
147+
"""
148+
return True
149+
128150
@staticmethod
129151
def _has_shared_q_params_if_quantized(node: Node) -> bool:
130152
"""Check if node has shared quantization parameters if it's quantized."""

backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1515
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1616
CustomDelegationOptions,
17+
is_not_qdq_node,
1718
NodeConverter,
1819
)
1920
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reshape_transposition import (
@@ -23,6 +24,7 @@
2324
reshape_options,
2425
)
2526
from torch.fx import Node
27+
from torch.fx.passes.infra.partitioner import Partition
2628
from torch.nn import Parameter
2729

2830

@@ -45,6 +47,27 @@ def _is_supported_in_IR(
4547

4648
return True
4749

50+
@classmethod
51+
def supports_partitioning_result(
52+
cls,
53+
node: Node,
54+
partition_list: list[Partition],
55+
custom_delegation_options: CustomDelegationOptions,
56+
):
57+
view_copy_partitions = [
58+
partition for partition in partition_list if node in partition.nodes
59+
]
60+
assert len(view_copy_partitions) == 1
61+
non_q_dq_partition_nodes = list(
62+
filter(is_not_qdq_node, view_copy_partitions[0].nodes)
63+
)
64+
65+
if len(non_q_dq_partition_nodes) == 1:
66+
# The `view_copy` cannot be the only node in a partition.
67+
return False
68+
69+
return True
70+
4871
@staticmethod
4972
def _safe_compute_flat_size(shape: list[int | str]) -> int:
5073
"""Compute the flat size of a tensor with given shape. Strings and negative dimensions are treated as '1'.

backends/nxp/neutron_partitioner.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
)
2121
from executorch.backends.nxp.backend.ir.converter.node_converter import Target
2222
from torch.export.exported_program import ExportedProgram
23-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
23+
from torch.fx import Graph
24+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
2425
from torch.fx.passes.operator_support import OperatorSupportBase
2526
from torch.nn import Parameter
2627
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
@@ -34,6 +35,9 @@
3435
from executorch.exir.backend.utils import tag_constant_data
3536
from executorch.exir.dialects._ops import ops as exir_ops
3637

38+
NXP_DO_NOT_DELEGATE = "NXP_DO_NOT_DELEGATE"
39+
NXP_DELEGATION_TAG = "delegation_tag"
40+
3741

3842
class QDQClusterRecognizer:
3943
"""
@@ -246,6 +250,11 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool:
246250
"""
247251
Operator checking function for compute nodes.
248252
"""
253+
254+
if hasattr(node, "meta") and node.meta.get(NXP_DO_NOT_DELEGATE, False):
255+
# The delegation of this node has been prohibited.
256+
return False
257+
249258
if not self.is_node_delegatable(node):
250259
return False
251260

@@ -304,6 +313,31 @@ def __init__(
304313
custom_delegation_options or CustomDelegationOptions()
305314
)
306315

316+
def validate_partitioning_result(
317+
self,
318+
graph: Graph,
319+
partition_list: list[Partition],
320+
custom_delegation_options: CustomDelegationOptions,
321+
) -> bool:
322+
all_delegated_nodes = {
323+
node for partition in partition_list for node in partition.nodes
324+
}
325+
partitioning_valid = True
326+
for node in graph.nodes:
327+
if (
328+
node in all_delegated_nodes
329+
and hasattr(node, "target")
330+
and node.target in supported_ops
331+
):
332+
if not supported_ops[node.target].supports_partitioning_result(
333+
node, partition_list, custom_delegation_options
334+
):
335+
# This node is not supported within its partition. Exclude it from delegation in the future.
336+
partitioning_valid = False
337+
node.meta[NXP_DO_NOT_DELEGATE] = True
338+
339+
return partitioning_valid
340+
307341
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
308342
# Run the CapabilityBasedPartitioner to return the largest possible
309343
# subgraphs containing the nodes with the tags
@@ -342,11 +376,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
342376
allows_single_node_partition=True,
343377
)
344378

345-
partition_list = capability_partitioner.propose_partitions()
379+
iteration_limit = len(exported_program.graph.nodes)
380+
for _ in range(iteration_limit):
381+
# Run the partitioning.
382+
partition_list = capability_partitioner.propose_partitions()
383+
384+
# Check if the nodes support the partitioning result. Mark the problematic nodes with `NXP_DO_NOT_DELEGATE`.
385+
partitioning_valid = self.validate_partitioning_result(
386+
exported_program.graph, partition_list, self.custom_delegation_options
387+
)
388+
if partitioning_valid:
389+
# The result of the partitioning is fine
390+
break
391+
392+
# Mark the partitions in the node `meta` attribute.
346393
for partition in partition_list:
347394
for node in partition.nodes:
348395
delegation_tag = f"tag{partition.id}"
349-
node.meta["delegation_tag"] = delegation_tag
396+
node.meta[NXP_DELEGATION_TAG] = delegation_tag
350397
partition_tags[delegation_tag] = self.delegation_spec
351398

352399
tag_constant_data(exported_program)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
import numpy as np
9+
import torch
10+
11+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
12+
ViewCopyConverter,
13+
)
14+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
15+
from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
18+
19+
class SingleViewCopyModule(torch.nn.Module):
20+
def __init__(self, new_shape: list[int]):
21+
super().__init__()
22+
self.new_shape = new_shape
23+
24+
def forward(self, x):
25+
return torch.reshape(x, self.new_shape)
26+
27+
28+
class TestContextSensitiveDelegation(unittest.TestCase):
29+
__test__ = False # Prevent interfering with PyTest tests.
30+
31+
@classmethod
32+
def setUpClass(cls):
33+
torch.manual_seed(23)
34+
np.random.seed(42)
35+
36+
def test_single_view_copy_partition(self):
37+
input_shape = (2, 10)
38+
module = SingleViewCopyModule([1, 20])
39+
40+
ep = to_quantized_edge_program(module, input_shape).exported_program()
41+
42+
# Make sure the `view_copy` was not delegated.
43+
assert graph_contains_any_of_ops(
44+
ep.graph, [exir_ops.edge.aten.view_copy.default]
45+
)
46+
assert not any("delegate" in n.name for n in ep.graph.nodes)
47+
48+
def test_single_view_copy_partition__forced_delegation(self):
49+
input_shape = (2, 10)
50+
module = SingleViewCopyModule([1, 20])
51+
52+
def _supported_partitioning(*_):
53+
return True
54+
55+
# Replace the partition support check function, to accept anything.
56+
original_supports_partitioning_result = (
57+
ViewCopyConverter.supports_partitioning_result
58+
)
59+
ViewCopyConverter.supports_partitioning_result = _supported_partitioning
60+
61+
with self.assertRaises(RuntimeError) as e:
62+
to_quantized_edge_program(module, input_shape).exported_program()
63+
assert (
64+
str(e.exception)
65+
== "Model converted with neutron-converter does not contain a NeutronGraph node."
66+
)
67+
68+
# Return to the original partition support check function.
69+
ViewCopyConverter.supports_partitioning_result = (
70+
original_supports_partitioning_result
71+
)

0 commit comments

Comments
 (0)