
Commit 068ec0f

support channels last inputs in xnnpack
1 parent 9d726e8 commit 068ec0f

File tree

5 files changed: +215 -103 lines changed


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 45 additions & 6 deletions
@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import string
+from logging import FATAL
+from tokenize import String
 from typing import Optional, Tuple
 
 import torch
@@ -56,9 +59,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
 
     # Set of ops that require memory format to be NCHW
     memory_sensitive_ops_nchw = {
-        "output",
         exir_ops.edge.aten.squeeze_copy.dim,
         exir_ops.edge.aten.unsqueeze_copy.default,
+        exir_ops.edge.aten.linear.default,
     }
 
     # Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +94,20 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nhwc
+        return (
+            node.target in self.memory_sensitive_ops_nhwc
+            or node.name == "output"
+            and not node.args[0][0].meta["val"].is_contiguous()
+        )
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nchw
+        return (
+            node.target in self.memory_sensitive_ops_nchw
+            or node.name == "output"
+            and node.args[0][0]
+            .meta["val"]
+            .is_contiguous()  # Need to consider output trace so out matches
+        )
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -269,8 +282,17 @@ def input_to_nhwc(
             # serializing graph, but don't do anything else here
             self.mark_as_nhwc_node(input_node)
 
-        if self.is_nhwc_node(input_node):
+        if input_node.op == "placeholder":
+            if not input_node.meta["val"][0].is_contiguous():
+                return
+        elif self.is_nhwc_node(input_node):
             return
+        # if (
+        #     self.is_nhwc_node(input_node)
+        #     or input_node.op == "placeholder"
+        #     and not input_node.meta["val"][0].is_contiguous()
+        # ):
+        #     return
 
         if not self.can_be_converted_to_nhwc(input_node):
             raise AssertionError(
@@ -333,8 +355,21 @@ def input_to_nchw(
             # do anything else here
             self.mark_as_nchw_node(input_node)
 
-        if self.is_nchw_node(input_node):
+        if input_node.op == "placeholder":
+            if input_node.meta["val"][0].is_contiguous():
+                return
+        elif self.is_nchw_node(input_node):
             return
+        # TODO
+        # meta trace happens before passes. At the end of pass, meta gets regenerated. eager mode assumes in/out stay same for conv. Linear has implicit nchw conv
+        # if (
+        #     self.is_nchw_node(
+        #         input_node
+        #     )  # This is triggering as x (placeholder) is tagged as nchw
+        #     or input_node.op == "placeholder"
+        #     and input_node.meta["val"][0].is_contiguous()
+        # ):
+        #     return
 
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
             # Already has an associated NCHW node
@@ -371,7 +406,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # first input to be nhwc. This makes this node's output nhwc too
                 # Currently, all nodes like this should have all of their other
                 # inputs as nchw, so fail if this is not true
-                self.input_to_nhwc(graph_module, node.args[0], node)
+                if node.name == "output":
+                    self.input_to_nhwc(graph_module, node.args[0][0], node)
+                else:
+                    self.input_to_nhwc(graph_module, node.args[0], node)
+
                 for input_node in node.all_input_nodes[1:]:
                     if self.is_nhwc_node(input_node):
                         raise AssertionError(
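The pass now keys off the traced placeholder's memory format instead of assuming every graph input is NCHW: a non-contiguous (channels-last) input placeholder no longer gets an NCHW -> NHWC copy inserted, and the output node is converted based on whether its traced value is contiguous. Below is a minimal sketch of the placeholder check under a plain torch.export flow; the module and variable names are illustrative and not part of this commit (the pass itself additionally indexes into meta["val"]).

```python
import torch


class SimpleConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)


# A channels-last example input makes the traced placeholder's fake tensor
# report is_contiguous() == False, which is what the pass inspects before
# deciding whether to insert a copy node.
example = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
program = torch.export.export(SimpleConv().eval(), (example,))

for node in program.graph.nodes:
    if node.op == "placeholder":
        # Lifted parameter placeholders stay contiguous; the user input does not.
        print(node.name, node.meta["val"].is_contiguous())
```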

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 29 additions & 9 deletions
@@ -106,20 +106,24 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
         err == Error::Ok,
         Internal,
         "Failed to retrieve dim order from tensor!");
-    // ET_CHECK_OR_RETURN_ERROR(
-    //     is_contiguous_dim_order(dim_order, tensor->dim()),
-    //     Internal,
-    //     "Expecting default dim_order but got a non default dim_order tensor for external input %u",
-    //     i);
     size_t dims[XNN_MAX_TENSOR_DIMS];
     ET_CHECK_OR_RETURN_ERROR(
         num_dims <= XNN_MAX_TENSOR_DIMS,
         InvalidArgument,
         "XNNPACK backend accepts tensors with at most %d dims, but got %zu",
         XNN_MAX_TENSOR_DIMS,
         num_dims);
-    for (int d = 0; d < num_dims; ++d) {
-      dims[d] = tensor->size(d);
+
+    bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dims);
+    if (is_channels_last) {
+      dims[0] = tensor->size(0);
+      dims[1] = tensor->size(2);
+      dims[2] = tensor->size(3);
+      dims[3] = tensor->size(1);
+    } else {
+      for (int d = 0; d < num_dims; ++d) {
+        dims[d] = tensor->size(d);
+      }
     }
     status =
         xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +224,24 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
 
     // Convert new output shape into SizesType
     SizesType expected_output_size[kTensorDimensionLimit];
-    for (size_t d = 0; d < num_dim; ++d) {
-      expected_output_size[d] = static_cast<SizesType>(dims[d]);
+    executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
+    Error errr =
+        ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
+    ET_CHECK_OR_RETURN_ERROR(
+        errr == Error::Ok,
+        Internal,
+        "Failed to retrieve dim order from tensor!");
+
+    bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dim);
+    if (is_channels_last) {
+      expected_output_size[0] = static_cast<SizesType>(dims[0]);
+      expected_output_size[1] = static_cast<SizesType>(dims[3]);
+      expected_output_size[2] = static_cast<SizesType>(dims[1]);
+      expected_output_size[3] = static_cast<SizesType>(dims[2]);
+    } else {
+      for (size_t d = 0; d < num_dim; ++d) {
+        expected_output_size[d] = static_cast<SizesType>(dims[d]);
+      }
     }
 
     executorch::aten::ArrayRef<SizesType> output_size{
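In the runtime, tensor sizes stay in logical NCHW order while a channels-last external value is described to XNNPACK in physical NHWC order, so prepare_args permutes the sizes on the way in and resize_outputs permutes them back on the way out. A small sketch of the two index mappings above, written in Python for brevity; only the C++ in the diff is authoritative:

```python
def nchw_sizes_to_xnn_dims(sizes):
    # prepare_args: logical (N, C, H, W) -> physical (N, H, W, C)
    n, c, h, w = sizes
    return [n, h, w, c]


def xnn_dims_to_nchw_sizes(dims):
    # resize_outputs: physical (N, H, W, C) -> logical (N, C, H, W)
    n, h, w, c = dims
    return [n, c, h, w]


assert nchw_sizes_to_xnn_dims([1, 3, 8, 8]) == [1, 8, 8, 3]
assert xnn_dims_to_nchw_sizes(nchw_sizes_to_xnn_dims([1, 3, 8, 8])) == [1, 3, 8, 8]
```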

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 119 additions & 79 deletions
@@ -7,6 +7,7 @@
 import unittest
 
 import torch
+from backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
@@ -58,48 +59,87 @@ def setUp(self):
     # .run_method_and_compare_outputs()
     # )
 
-    # def test_channels_last_input_graph_transformation(self):
-    #     # Define a simple module for testing
-    #     class SimpleModule(torch.nn.Module):
-    #         def __init__(self):
-    #             super().__init__()
-    #             self.conv = torch.nn.Conv2d(3, 3, 3)
-    #         def forward(self, x):
-    #             return self.conv(x)
-    #     # Create a tester instance with NHWC input
-    #     tester = Tester(SimpleModule().eval(), (torch.randn(1, 3, 3, 3).to(memory_format=torch.channels_last),))
-    #     # Run the export and pass stages
-    #     tester.export().to_edge().run_passes(self.PassStage)
-    #     # Check the graph for expected nodes
-    #     tester.check_count({
-    #         "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2, # should be 1 but its 2
-    #         "executorch_exir_dialects_edge__ops_aten_convolution_default": 1
-    #     })
-    #     tester.dump_artifact()
-
-    def test_nhwc_input(self):
-        class SimpleModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-            def forward(self, x):
-                return self.conv(x)
-
-        tester = Tester(SimpleModule().eval(), (torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last),))
-
-        tester2 = Tester(SimpleModule().eval(), (torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last),))
-        tester2.export().to_edge().run_passes(self.PassStage).dump_artifact()
-
-
-        tester.export() \
-            .to_edge_transform_and_lower() \
-            .dump_artifact()\
-            .to_executorch() \
-            .dump_artifact()\
-            .serialize() \
-            .run_method_and_compare_outputs()
+    class LinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            return self.conv1(y)
+
+    def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_conv_linear_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ConvLinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            return self.linear1(y)
+
+    def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
+        )
 
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
 
+    def test_linear_conv_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class Bilinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.nn.functional.interpolate(
+                x, scale_factor=2, mode="bilinear", align_corners=True
+            )
+
+    def test_nhwc_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (
+                torch.arange(8)
+                .reshape(1, 2, 2, 2)
+                .to(torch.float32)
+                .to(memory_format=torch.channels_last),
+            ),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_nchw_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
 
     # def test_qs8_channels_last_tagged_reshape_pass(self):
     #     for module, num_reshape in self.modules.items():
@@ -190,45 +230,45 @@ def forward(self, x):
             return x
 
     # def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
-    # Copy #1 is for input to conv, nchw -> nhwc
-    # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw
-    # Copy #3 is for input to mean, nchw -> nhwc
-    # Copy #4 is for output, nhwc -> nchw
-
-    # The graph looks like:
-    # graph():
-    # %arg0_1 : [#users=1] = placeholder[target=arg0_1]
-    # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last})
-    # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
-    # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
-    # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
-    # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format})
-    # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
-    # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
-    # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
-    # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1]
-    # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {})
-    # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {})
-    # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {})
-    # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last})
-    # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {})
-    # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format})
-    # return [aten__to_copy_default_3]
-    # (
-    #     Tester(
-    #         self.Conv2dBnHardtanhMeanSequenceModule().eval(),
-    #         (torch.randn(1, 1, 6, 6),),
-    #     )
-    #     .export()
-    #     .to_edge()
-    #     .run_passes(self.PassStage)
-    #     .check_count(
-    #         {
-    #             self.to_copy_name: 4,
-    #         }
-    #     )
-    #     .run_method_and_compare_outputs()
-    # )
+        # Copy #1 is for input to conv, nchw -> nhwc
+        # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw
+        # Copy #3 is for input to mean, nchw -> nhwc
+        # Copy #4 is for output, nhwc -> nchw
+
+        # The graph looks like:
+        # graph():
+        # %arg0_1 : [#users=1] = placeholder[target=arg0_1]
+        # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last})
+        # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
+        # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
+        # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
+        # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format})
+        # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
+        # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
+        # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
+        # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1]
+        # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {})
+        # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {})
+        # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {})
+        # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last})
+        # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {})
+        # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format})
+        # return [aten__to_copy_default_3]
+        # (
+        #     Tester(
+        #         self.Conv2dBnHardtanhMeanSequenceModule().eval(),
+        #         (torch.randn(1, 1, 6, 6),),
+        #     )
+        #     .export()
+        #     .to_edge()
+        #     .run_passes(self.PassStage)
+        #     .check_count(
+        #         {
+        #             self.to_copy_name: 4,
+        #         }
+        #     )
+        #     .run_method_and_compare_outputs()
+        # )
 
     class Conv2dDynamicQuant(torch.nn.Module):
         def __init__(self):
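The new tests feed 4-D channels-last tensors directly into the lowered program. For reference, .to(memory_format=torch.channels_last) keeps the logical NCHW shape and only permutes the strides, which is why the sizes in the diffs above remain in NCHW order; a quick standalone illustration (not part of the commit):

```python
import torch

x = torch.randn(1, 3, 6, 4)
x_cl = x.to(memory_format=torch.channels_last)

print(x_cl.shape)    # torch.Size([1, 3, 6, 4]) -- logical shape unchanged
print(x_cl.stride())  # (72, 1, 12, 3) -- channel stride becomes 1 (NHWC layout)
print(x_cl.is_contiguous())                                   # False
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
```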
