|
20 | 20 | ) |
21 | 21 | from executorch.backends.nxp.backend.ir.converter.node_converter import Target |
22 | 22 | from torch.export.exported_program import ExportedProgram |
23 | | -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner |
| 23 | +from torch.fx import Graph |
| 24 | +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition |
24 | 25 | from torch.fx.passes.operator_support import OperatorSupportBase |
25 | 26 | from torch.nn import Parameter |
26 | 27 | from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 |
|
34 | 35 | from executorch.exir.backend.utils import tag_constant_data |
35 | 36 | from executorch.exir.dialects._ops import ops as exir_ops |
36 | 37 |
|
| 38 | +NXP_DO_NOT_DELEGATE = "NXP_DO_NOT_DELEGATE" |
| 39 | +NXP_DELEGATION_TAG = "delegation_tag" |
| 40 | + |
37 | 41 |
|
38 | 42 | class QDQClusterRecognizer: |
39 | 43 | """ |
@@ -246,6 +250,11 @@ def _is_node_supported_compute(self, node: torch.fx.node.Node) -> bool: |
246 | 250 | """ |
247 | 251 | Operator checking function for compute nodes. |
248 | 252 | """ |
| 253 | + |
| 254 | + if hasattr(node, "meta") and node.meta.get(NXP_DO_NOT_DELEGATE, False): |
| 255 | + # The delegation of this node has been prohibited. |
| 256 | + return False |
| 257 | + |
249 | 258 | if not self.is_node_delegatable(node): |
250 | 259 | return False |
251 | 260 |
|
@@ -304,6 +313,31 @@ def __init__( |
304 | 313 | custom_delegation_options or CustomDelegationOptions() |
305 | 314 | ) |
306 | 315 |
|
| 316 | + def validate_partitioning_result( |
| 317 | + self, |
| 318 | + graph: Graph, |
| 319 | + partition_list: list[Partition], |
| 320 | + custom_delegation_options: CustomDelegationOptions, |
| 321 | + ) -> bool: |
| 322 | + all_delegated_nodes = { |
| 323 | + node for partition in partition_list for node in partition.nodes |
| 324 | + } |
| 325 | + partitioning_valid = True |
| 326 | + for node in graph.nodes: |
| 327 | + if ( |
| 328 | + node in all_delegated_nodes |
| 329 | + and hasattr(node, "target") |
| 330 | + and node.target in supported_ops |
| 331 | + ): |
| 332 | + if not supported_ops[node.target].supports_partitioning_result( |
| 333 | + node, partition_list, custom_delegation_options |
| 334 | + ): |
| 335 | + # This node is not supported within its partition. Exclude it from delegation in the future. |
| 336 | + partitioning_valid = False |
| 337 | + node.meta[NXP_DO_NOT_DELEGATE] = True |
| 338 | + |
| 339 | + return partitioning_valid |
| 340 | + |
307 | 341 | def partition(self, exported_program: ExportedProgram) -> PartitionResult: |
308 | 342 | # Run the CapabilityBasedPartitioner to return the largest possible |
309 | 343 | # subgraphs containing the nodes with the tags |
@@ -342,11 +376,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: |
342 | 376 | allows_single_node_partition=True, |
343 | 377 | ) |
344 | 378 |
|
345 | | - partition_list = capability_partitioner.propose_partitions() |
| 379 | + iteration_limit = len(exported_program.graph.nodes) |
| 380 | + for _ in range(iteration_limit): |
| 381 | + # Run the partitioning. |
| 382 | + partition_list = capability_partitioner.propose_partitions() |
| 383 | + |
| 384 | + # Check if the nodes support the partitioning result. Mark the problematic nodes with `NXP_DO_NOT_DELEGATE`. |
| 385 | + partitioning_valid = self.validate_partitioning_result( |
| 386 | + exported_program.graph, partition_list, self.custom_delegation_options |
| 387 | + ) |
| 388 | + if partitioning_valid: |
| 389 | + # The result of the partitioning is fine |
| 390 | + break |
| 391 | + |
| 392 | + # Mark the partitions in the node `meta` attribute. |
346 | 393 | for partition in partition_list: |
347 | 394 | for node in partition.nodes: |
348 | 395 | delegation_tag = f"tag{partition.id}" |
349 | | - node.meta["delegation_tag"] = delegation_tag |
| 396 | + node.meta[NXP_DELEGATION_TAG] = delegation_tag |
350 | 397 | partition_tags[delegation_tag] = self.delegation_spec |
351 | 398 |
|
352 | 399 | tag_constant_data(exported_program) |
|
0 commit comments