Skip to content

Commit 75f968d

Browse files
authored
Make determinism of channels_last more conservative
Differential Revision: D83998877 Pull Request resolved: #14862
1 parent 29b98c3 commit 75f968d

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def is_nhwc_node(node: torch.fx.Node) -> bool:
110110
if len(quantize_node.all_input_nodes) > 0:
111111
actual_node = quantize_node.args[0]
112112
if actual_node.op == "placeholder":
113-
return not actual_node.meta["val"][0].is_contiguous()
113+
return ChannelsLastTaggedReshapePass._is_nhwc_tensor(
114+
actual_node.meta["val"][0]
115+
)
114116
else:
115117
return actual_node.meta.get(
116118
ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
@@ -125,14 +127,36 @@ def is_nchw_node(node: torch.fx.Node) -> bool:
125127
if len(quantize_node.all_input_nodes) > 0:
126128
actual_node = quantize_node.args[0]
127129
if actual_node.op == "placeholder":
128-
return actual_node.meta["val"][0].is_contiguous()
130+
return not ChannelsLastTaggedReshapePass._is_nhwc_tensor(
131+
actual_node.meta["val"][0]
132+
)
129133
else:
130134
return not actual_node.meta.get(
131135
ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False
132136
)
133137

134138
return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
135139

140+
@staticmethod
141+
def _is_nhwc_tensor(tensor: torch.Tensor) -> bool:
142+
nhwc = tensor.is_contiguous(memory_format=torch.channels_last)
143+
nchw = tensor.is_contiguous()
144+
# if both are true false
145+
# if both nchw and nhwc are true
146+
# then we want to see this is nchw hence return false
147+
# if either of nchw or nhwc is false, then just rely on hwc
148+
# if both are false, mayb channels_last_3d, then return nhwc
149+
# however this should not happen here
150+
# return (not (nchw and nhwc)) and nhwc
151+
# Readable version
152+
if nchw and nhwc:
153+
return False
154+
else:
155+
return nhwc
156+
157+
def _is_nhwc(self, tensor: torch.Tensor) -> bool:
158+
return ChannelsLastTaggedReshapePass._is_nhwc_tensor(tensor)
159+
136160
def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
137161
return node.target in self.memory_sensitive_ops_nhwc
138162

@@ -315,11 +339,8 @@ def input_dim_order(
315339
self, input_node: torch.fx.Node, input_order: InputDimOrder
316340
) -> bool:
317341
if input_node.op == "placeholder":
318-
return (
319-
input_node.meta["val"].is_contiguous()
320-
if input_order == InputDimOrder.NCHW
321-
else not input_node.meta["val"].is_contiguous()
322-
)
342+
is_nhwc = self._is_nhwc(input_node.meta["val"])
343+
return not is_nhwc if input_order == InputDimOrder.NCHW else is_nhwc
323344
else:
324345
return (
325346
ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
@@ -348,7 +369,7 @@ def input_to_nhwc(
348369
self.mark_as_nhwc_node(input_node)
349370

350371
if input_node.op == "placeholder":
351-
if not input_node.meta["val"][0].is_contiguous():
372+
if self._is_nhwc(input_node.meta["val"][0]):
352373
return
353374
elif ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
354375
return
@@ -420,7 +441,7 @@ def input_to_nchw(
420441
self.mark_as_nchw_node(input_node)
421442

422443
if input_node.op == "placeholder":
423-
if input_node.meta["val"].is_contiguous():
444+
if not self._is_nhwc(input_node.meta["val"]):
424445
return
425446
elif ChannelsLastTaggedReshapePass.is_nchw_node(input_node):
426447
return
@@ -462,17 +483,17 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
462483
and isinstance(node.meta["val"], torch.Tensor)
463484
and len(node.meta["val"].shape) == 4
464485
):
465-
if node.meta["val"].is_contiguous():
466-
self.mark_as_nchw_node(node)
467-
else:
486+
if self._is_nhwc(node.meta["val"]):
468487
self.mark_as_nhwc_node(node)
488+
else:
489+
self.mark_as_nchw_node(node)
469490
continue
470491

471492
# Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes
472493
if node.op == "output":
473494
out_tuple = node.args[0]
474495
for out_node in out_tuple:
475-
if out_node.meta["val"].is_contiguous():
496+
if not self._is_nhwc(out_node.meta["val"]):
476497
self.input_to_nchw(graph_module, out_node, node)
477498
else:
478499
self.input_to_nhwc(graph_module, out_node, node)

0 commit comments

Comments
 (0)