44# LICENSE file in the root directory of this source tree.
55
66import logging
7- from enum import Enum
7+ import operator
88
9+ from executorch .backends .nxp .backend .edge_program_converter import functions_converters
10+ from executorch .backends .nxp .backend .node_format import NodeFormat , NXP_NODE_FORMAT
911from executorch .exir .dialects ._ops import ops as exir_ops
12+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
1013
11- from torch import Node
1214from torch .export import ExportedProgram
15+ from torch .fx import Node
1316
1417logger = logging .getLogger (__name__ )
1518
16- NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17-
18-
19- class NodeFormat (Enum ):
20- # Node's output in NCHW format
21- CHANNELS_FIRST = 0
22-
23- # Node's output format has no meaning
24- FORMATLESS = 1
25-
26- # Format has not been identified
27- NONE = 2
28-
29- def is_channels_first (self ) -> bool :
30- return self == NodeFormat .CHANNELS_FIRST
31-
3219
3320class NodeFormatInference :
3421 # Dictionary with Edge Aten ops that always use channels first format.
@@ -53,6 +40,9 @@ class NodeFormatInference:
5340 # Mapping between Node and its children (outputs)
5441 _node_outputs : dict [Node , list [Node ]]
5542
43+ # List of all edge operations, which are supported by the converter.
44+ _known_targets : list [EdgeOpOverload ]
45+
5646 def __init__ (self , edge_program : ExportedProgram ):
5747 self ._edge_program = edge_program
5848
@@ -66,6 +56,12 @@ def __init__(self, edge_program: ExportedProgram):
6656
6757 self ._type_changed_during_last_run = False
6858
59+ self ._known_targets = list (functions_converters ) + [
60+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
61+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
62+ operator .getitem ,
63+ ]
64+
6965 def identify_node_formats (self ):
7066 self ._type_changed_during_last_run = True
7167
@@ -100,9 +96,19 @@ def _infer_format_of_nodes(self, node: Node):
10096 logger .error (
10197 f"Node format inference for node type: { op_type } not found!"
10298 )
103- else :
99+ elif node .op != "call_function" or (
100+ hasattr (node , "target" ) and node .target in self ._known_targets
101+ ):
102+ # Generic node, or tensor.
104103 self ._handle_node_which_can_use_any_node_format (node )
105104
105+ else :
106+ # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide
107+ # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these
108+ # partitions, which would require extra transpositions.
109+ for processed_node in self ._node_inputs [node ] + [node ]:
110+ self ._assign_format_to_node (processed_node , NodeFormat .NONE )
111+
106112 def _infer_format_based_on_io_ranks (self , node : Node ):
107113 """Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input
108114 and output.
@@ -155,6 +161,10 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
155161 # Once CHANNEL_FIRST was assigned, we don't want to reassign
156162 return
157163
164+ if node_format is NodeFormat .NONE and old_node_format is not NodeFormat .NONE :
165+ # A format has already been assigned to the node before. Don't replace it with `NONE`.
166+ return
167+
158168 if old_node_format != node_format :
159169 self ._type_changed_during_last_run = True
160170
0 commit comments