Commit 2374f89

support channels last dim order in xnnpack
1 parent d9503e6 commit 2374f89

File tree: 5 files changed, +145 -21 lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 26 additions & 6 deletions
@@ -56,9 +56,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
 
     # Set of ops that require memory format to be NCHW
     memory_sensitive_ops_nchw = {
-        "output",
         exir_ops.edge.aten.squeeze_copy.dim,
         exir_ops.edge.aten.unsqueeze_copy.default,
+        exir_ops.edge.aten.linear.default,
     }
 
     # Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +91,20 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nhwc
+        return (
+            node.target in self.memory_sensitive_ops_nhwc
+            or node.name == "output"
+            and not node.args[0][0].meta["val"].is_contiguous()
+        )
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nchw
+        return (
+            node.target in self.memory_sensitive_ops_nchw
+            or node.name == "output"
+            and node.args[0][0]
+            .meta["val"]
+            .is_contiguous()
+        )
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -269,7 +279,10 @@ def input_to_nhwc(
             # serializing graph, but don't do anything else here
             self.mark_as_nhwc_node(input_node)
 
-        if self.is_nhwc_node(input_node):
+        if input_node.op == "placeholder":
+            if not input_node.meta["val"][0].is_contiguous():
+                return
+        elif self.is_nhwc_node(input_node):
             return
 
         if not self.can_be_converted_to_nhwc(input_node):
@@ -333,7 +346,10 @@ def input_to_nchw(
            # do anything else here
            self.mark_as_nchw_node(input_node)
 
-        if self.is_nchw_node(input_node):
+        if input_node.op == "placeholder":
+            if input_node.meta["val"][0].is_contiguous():
+                return
+        elif self.is_nchw_node(input_node):
             return
 
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -371,7 +387,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # first input to be nhwc. This makes this node's output nhwc too
                 # Currently, all nodes like this should have all of their other
                 # inputs as nchw, so fail if this is not true
-                self.input_to_nhwc(graph_module, node.args[0], node)
+                if node.name == "output":
+                    self.input_to_nhwc(graph_module, node.args[0][0], node)
+                else:
+                    self.input_to_nhwc(graph_module, node.args[0], node)
+
                 for input_node in node.all_input_nodes[1:]:
                     if self.is_nhwc_node(input_node):
                         raise AssertionError(
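
The pass now keys off the graph output's actual layout instead of always forcing it to NCHW: an output fed by a non-contiguous (channels-last) tensor requires NHWC input, while a contiguous one requires NCHW. A minimal sketch of the PyTorch semantics those is_contiguous() checks rely on, assuming a recent PyTorch that provides Tensor.dim_order():

    import torch

    nchw = torch.randn(1, 3, 6, 4)  # default dim order (0, 1, 2, 3)
    nhwc = nchw.to(memory_format=torch.channels_last)  # dim order (0, 2, 3, 1)

    # Channels-last tensors are "non-contiguous" in the default (NCHW) sense,
    # which is exactly what the new output-node checks detect.
    assert nchw.is_contiguous() and nchw.dim_order() == (0, 1, 2, 3)
    assert not nhwc.is_contiguous() and nhwc.dim_order() == (0, 2, 3, 1)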

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 29 additions & 9 deletions
@@ -106,20 +106,24 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
         err == Error::Ok,
         Internal,
         "Failed to retrieve dim order from tensor!");
-    ET_CHECK_OR_RETURN_ERROR(
-        is_contiguous_dim_order(dim_order, tensor->dim()),
-        Internal,
-        "Expecting default dim_order but got a non default dim_order tensor for external input %u",
-        i);
     size_t dims[XNN_MAX_TENSOR_DIMS];
     ET_CHECK_OR_RETURN_ERROR(
         num_dims <= XNN_MAX_TENSOR_DIMS,
         InvalidArgument,
         "XNNPACK backend accepts tensors with at most %d dims, but got %zu",
         XNN_MAX_TENSOR_DIMS,
         num_dims);
-    for (int d = 0; d < num_dims; ++d) {
-      dims[d] = tensor->size(d);
+
+    bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dims);
+    if (is_channels_last) {
+      dims[0] = tensor->size(0);
+      dims[1] = tensor->size(2);
+      dims[2] = tensor->size(3);
+      dims[3] = tensor->size(1);
+    } else {
+      for (int d = 0; d < num_dims; ++d) {
+        dims[d] = tensor->size(d);
+      }
     }
     status =
         xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +224,24 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
 
     // Convert new output shape into SizesType
     SizesType expected_output_size[kTensorDimensionLimit];
-    for (size_t d = 0; d < num_dim; ++d) {
-      expected_output_size[d] = static_cast<SizesType>(dims[d]);
+    executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
+    Error errr =
+        ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
+    ET_CHECK_OR_RETURN_ERROR(
+        errr == Error::Ok,
+        Internal,
+        "Failed to retrieve dim order from tensor!");
+
+    bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dim);
+    if (is_channels_last) {
+      expected_output_size[0] = static_cast<SizesType>(dims[0]);
+      expected_output_size[1] = static_cast<SizesType>(dims[3]);
+      expected_output_size[2] = static_cast<SizesType>(dims[1]);
+      expected_output_size[3] = static_cast<SizesType>(dims[2]);
+    } else {
+      for (size_t d = 0; d < num_dim; ++d) {
+        expected_output_size[d] = static_cast<SizesType>(dims[d]);
+      }
     }
 
     executorch::aten::ArrayRef<SizesType> output_size{
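
XNNPACK expects the sizes of a channels-last tensor in NHWC order, while ExecuTorch tensors report sizes in logical NCHW order, so prepare_args permutes sizes on the way in and resize_outputs applies the inverse permutation on the way out. A sketch of the two mappings (helper names are illustrative, not part of the runtime):

    def nchw_sizes_to_nhwc(sizes):
        # prepare_args: dims[0..3] = size(0), size(2), size(3), size(1)
        n, c, h, w = sizes
        return [n, h, w, c]

    def nhwc_sizes_to_nchw(sizes):
        # resize_outputs: expected[0..3] = dims[0], dims[3], dims[1], dims[2]
        n, h, w, c = sizes
        return [n, c, h, w]

    # The two permutations are inverses of each other:
    assert nhwc_sizes_to_nchw(nchw_sizes_to_nhwc([1, 3, 6, 4])) == [1, 3, 6, 4]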

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 83 additions & 0 deletions
@@ -7,6 +7,7 @@
 import unittest
 
 import torch
+from backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
     ChannelsLastTaggedReshapePass,
 )
@@ -58,6 +59,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
+    class LinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            return self.conv1(y)
+
+    def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_conv_linear_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ConvLinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            return self.linear1(y)
+
+    def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_linear_conv_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class Bilinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.nn.functional.interpolate(
+                x, scale_factor=2, mode="bilinear", align_corners=True
+            )
+
+    def test_nhwc_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (
+                torch.arange(8)
+                .reshape(1, 2, 2, 2)
+                .to(torch.float32)
+                .to(memory_format=torch.channels_last),
+            ),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_nchw_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
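
The Bilinear tests use arange rather than randn, which makes layout mix-ups easy to spot: with distinct, ordered values, reading an NHWC buffer as NCHW produces a visibly permuted output instead of one that might slip past a fuzzy comparison. A sketch of the logical/physical split those inputs exercise:

    import torch

    x = torch.arange(8, dtype=torch.float32).reshape(1, 2, 2, 2)
    nhwc = x.to(memory_format=torch.channels_last)

    assert torch.equal(x, nhwc)  # logically identical values
    # ...but laid out differently in memory (NCHW vs NHWC physical order):
    print(x.reshape(-1))                         # tensor([0., 1., 2., 3., 4., 5., 6., 7.])
    print(nhwc.permute(0, 2, 3, 1).reshape(-1))  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])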

backends/xnnpack/test/tester/tester.py

Lines changed: 7 additions & 3 deletions
@@ -31,6 +31,7 @@
 )
 from executorch.exir.backend.backend_api import validation_disabled
 from executorch.exir.backend.partitioner import Partitioner
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.exir.print_program import pretty_print, print_program
@@ -533,10 +534,13 @@ def fn(x):
             # create random tensor inputs with the shapes given above:
             random_inputs = []
             for arg_idx in range(len(self.example_inputs)):
+                memFormat = get_memory_format(
+                    list(self.example_inputs[arg_idx].dim_order())
+                )
                 random_inputs.append(
-                    torch.randn(input_shapes[arg_idx]).to(
-                        dtype=self.example_inputs[arg_idx].dtype
-                    )
+                    torch.randn(input_shapes[arg_idx])
+                    .to(dtype=self.example_inputs[arg_idx].dtype)
+                    .to(memory_format=memFormat)
                 )
 
             yield tuple(random_inputs)
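
get_memory_format (imported above from executorch.exir.dim_order_utils) maps a tensor's dim-order list back to a torch memory format, so the randomly generated inputs reproduce the layout of the example inputs. Conceptually it behaves like this sketch; the real implementation lives in exir:

    import torch

    def get_memory_format_sketch(dim_order):
        # [0, 1, ..., n-1] is the default (contiguous) order;
        # [0, 2, 3, 1] is the 4-D channels-last (NHWC) order.
        if dim_order == list(range(len(dim_order))):
            return torch.contiguous_format
        if dim_order == [0, 2, 3, 1]:
            return torch.channels_last
        raise ValueError(f"unsupported dim order: {dim_order}")

    assert get_memory_format_sketch([0, 1, 2, 3]) == torch.contiguous_format
    assert get_memory_format_sketch([0, 2, 3, 1]) == torch.channels_last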

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 0 additions & 3 deletions
@@ -145,9 +145,6 @@ def preprocess(
 
     node_to_external_map = generate_node_to_external_map(ep, graph_module)
 
-    # Make sure all inputs are contiguous_format or NCHW or default dim order
-    assert_default_dim_order(graph_module)
-
     # TODO retrace the graph module to lift the new params may have
     # been added to the graph in passes
 