 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from enum import Enum
 from typing import Optional, Tuple

 import torch
@@ -19,6 +20,11 @@
 from executorch.exir.pass_base import PassResult


+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -83,17 +89,49 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
     # is done
     PARTNER_NODE = "XNN_CHANNELS_LAST_TAGGED_RESHAPE_PARTNER_NODE"

-    def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nhwc_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = True

-    def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
+    @staticmethod
+    def mark_as_nchw_node(node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False

-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    def tag_node(self, node: torch.fx.Node) -> None:
+        if node.kwargs["memory_format"] == torch.channels_last:
+            self.mark_as_nhwc_node(node)
+        else:
+            self.mark_as_nchw_node(node)
+
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return not actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)

-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        if is_dequant(node) and len(node.all_input_nodes) > 0:
+            quantize_node = node.args[0]
+            if len(quantize_node.all_input_nodes) > 0:
+                actual_node = quantize_node.args[0]
+                if actual_node.op == "placeholder":
+                    return actual_node.meta["val"][0].is_contiguous()
+                else:
+                    return not actual_node.meta.get(
+                        ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
+                    )
+
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)

     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return node.target in self.memory_sensitive_ops_nhwc
@@ -111,7 +149,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant

@@ -273,6 +311,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)

+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.op == "placeholder":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
@@ -282,7 +336,7 @@ def input_to_nhwc(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -296,7 +350,10 @@ def input_to_nhwc(
         if input_node.op == "placeholder":
             if not input_node.meta["val"][0].is_contiguous():
                 return
-        elif self.is_nhwc_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
             return

         if not self.can_be_converted_to_nhwc(input_node):
@@ -326,6 +383,8 @@ def input_to_nhwc(
                 args=(input_node,),
                 memory_format=torch.channels_last,
             )
+            # Tag the newly created NHWC copy (use the static method for consistency)
+            ChannelsLastTaggedReshapePass.mark_as_nhwc_node(input_node_nhwc)

         if is_dynamic_input:
             # Replace downstream input_nodes with NHWC node
@@ -348,7 +407,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -363,7 +422,10 @@ def input_to_nchw(
         if input_node.op == "placeholder":
             if input_node.meta["val"].is_contiguous():
                 return
-        elif self.is_nchw_node(input_node):
+        elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
+            return
+
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
             return

         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -380,6 +442,7 @@ def input_to_nchw(
                 args=(input_node,),
                 memory_format=torch.contiguous_format,
             )
+            ChannelsLastTaggedReshapePass.mark_as_nchw_node(input_node_nchw)

         self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
             graph_module=graph_module,
@@ -393,7 +456,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
         original_nodes = list(graph.nodes)
         for node in original_nodes:
             if len(node.all_input_nodes) == 0:
-                # This node has no inputs so we don't need to change anything
+                # This node has no inputs, so there is nothing to convert, but we still need to tag it with its dim order
+                if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
+                    if node.meta["val"].is_contiguous():
+                        self.mark_as_nchw_node(node)
+                    else:
+                        self.mark_as_nhwc_node(node)
                 continue

             # Need special case for the output node because it can have multiple output dim orders, as we can output a tuple of multiple nodes
@@ -407,10 +475,12 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
             elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-
                 self.input_to_nhwc(graph_module, node.args[0], node)
-                for input_node in node.all_input_nodes:
-                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
+                for input_node in node.all_input_nodes[1:]:
+                    if (
+                        input_node.op == "placeholder"
+                        and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    ):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -419,11 +489,14 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
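
For readers less familiar with PyTorch memory formats: the placeholder checks in this diff (calling `is_contiguous()` on `meta["val"]`) rely on the fact that a 4D tensor converted to `torch.channels_last` is no longer contiguous in the default NCHW sense. A minimal standalone sketch of that behavior, illustrative only and not part of this change:

```python
import torch

# Illustrative only: the dim-order tagging in this pass keys off Tensor.is_contiguous(),
# which reports contiguity with respect to the default (NCHW) memory format.
x = torch.randn(1, 3, 8, 8)                        # contiguous, NCHW strides
x_nhwc = x.to(memory_format=torch.channels_last)   # same values, NHWC strides

print(x.is_contiguous())         # True  -> would be tagged NCHW
print(x_nhwc.is_contiguous())    # False -> would be tagged NHWC
print(x_nhwc.is_contiguous(memory_format=torch.channels_last))  # True
```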