Skip to content

Commit e359d50

Browse files
committed
Support channels-last dim order in XNNPACK
1 parent d9503e6 commit e359d50

File tree

9 files changed

+247
-22
lines changed

9 files changed

+247
-22
lines changed

.vscode/launch.json

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
// Use IntelliSense to learn about possible attributes.
3+
// Hover to view descriptions of existing attributes.
4+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5+
"version": "0.2.0",
6+
"configurations": [
7+
{
8+
"name": "Debug CMake project",
9+
"type": "lldb", // https://github.com/vadimcn/vscode-lldb
10+
"request": "launch",
11+
"program": "${command:cmake.launchTargetPath}",
12+
"args": [
13+
"--model_path=./add.pte",
14+
]
15+
}
16+
]
17+
}

.vscode/settings.json

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
{
2+
"files.associations": {
3+
"cstdlib": "cpp",
4+
"__bit_reference": "cpp",
5+
"__hash_table": "cpp",
6+
"__locale": "cpp",
7+
"__node_handle": "cpp",
8+
"__split_buffer": "cpp",
9+
"__tree": "cpp",
10+
"__verbose_abort": "cpp",
11+
"array": "cpp",
12+
"bitset": "cpp",
13+
"cctype": "cpp",
14+
"charconv": "cpp",
15+
"clocale": "cpp",
16+
"cmath": "cpp",
17+
"complex": "cpp",
18+
"condition_variable": "cpp",
19+
"cstdarg": "cpp",
20+
"cstdint": "cpp",
21+
"cstdio": "cpp",
22+
"cstring": "cpp",
23+
"ctime": "cpp",
24+
"cwchar": "cpp",
25+
"cwctype": "cpp",
26+
"deque": "cpp",
27+
"execution": "cpp",
28+
"memory": "cpp",
29+
"forward_list": "cpp",
30+
"future": "cpp",
31+
"initializer_list": "cpp",
32+
"iomanip": "cpp",
33+
"ios": "cpp",
34+
"iosfwd": "cpp",
35+
"iostream": "cpp",
36+
"istream": "cpp",
37+
"limits": "cpp",
38+
"list": "cpp",
39+
"locale": "cpp",
40+
"map": "cpp",
41+
"mutex": "cpp",
42+
"new": "cpp",
43+
"optional": "cpp",
44+
"print": "cpp",
45+
"queue": "cpp",
46+
"ratio": "cpp",
47+
"regex": "cpp",
48+
"set": "cpp",
49+
"shared_mutex": "cpp",
50+
"sstream": "cpp",
51+
"stack": "cpp",
52+
"stdexcept": "cpp",
53+
"streambuf": "cpp",
54+
"string": "cpp",
55+
"string_view": "cpp",
56+
"typeindex": "cpp",
57+
"typeinfo": "cpp",
58+
"unordered_map": "cpp",
59+
"unordered_set": "cpp",
60+
"variant": "cpp",
61+
"vector": "cpp",
62+
"algorithm": "cpp",
63+
"iterator": "cpp",
64+
"tuple": "cpp",
65+
"span": "cpp"
66+
},
67+
"C_Cpp.default.compilerPath": "/library/developer/commandlinetools/usr/bin/c++",
68+
"python.analysis.typeCheckingMode": "off"
69+
}

CMakePresets.json

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,17 @@
104104
"Windows"
105105
]
106106
}
107-
}
107+
},
108+
{
109+
"name": "Executorch",
110+
"displayName": "Executorch",
111+
"description": "Sets the Ninja generator and the build and install directories",
112+
"generator": "Ninja",
113+
"binaryDir": "${sourceDir}/out/build/${presetName}",
114+
"cacheVariables": {
115+
"CMAKE_BUILD_TYPE": "Debug",
116+
"CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}"
117+
}
118+
}
108119
]
109120
}

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
5656

5757
# Set of ops that require memory format to be NCHW
5858
memory_sensitive_ops_nchw = {
59-
"output",
6059
exir_ops.edge.aten.squeeze_copy.dim,
6160
exir_ops.edge.aten.unsqueeze_copy.default,
61+
exir_ops.edge.aten.linear.default,
6262
}
6363

6464
# Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +91,20 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
9191
return not self.is_nhwc_node(node)
9292

9393
def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
94-
return node.target in self.memory_sensitive_ops_nhwc
94+
return (
95+
node.target in self.memory_sensitive_ops_nhwc
96+
or node.name == "output"
97+
and not node.args[0][0].meta["val"].is_contiguous()
98+
)
9599

96100
def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
97-
return node.target in self.memory_sensitive_ops_nchw
101+
return (
102+
node.target in self.memory_sensitive_ops_nchw
103+
or node.name == "output"
104+
and node.args[0][0]
105+
.meta["val"]
106+
.is_contiguous()
107+
)
98108

99109
def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
100110
# There are two conditions that must be met for a node to be able to
@@ -269,7 +279,10 @@ def input_to_nhwc(
269279
# serializing graph, but don't do anything else here
270280
self.mark_as_nhwc_node(input_node)
271281

272-
if self.is_nhwc_node(input_node):
282+
if input_node.op == "placeholder":
283+
if not input_node.meta["val"][0].is_contiguous():
284+
return
285+
elif self.is_nhwc_node(input_node):
273286
return
274287

275288
if not self.can_be_converted_to_nhwc(input_node):
@@ -333,7 +346,10 @@ def input_to_nchw(
333346
# do anything else here
334347
self.mark_as_nchw_node(input_node)
335348

336-
if self.is_nchw_node(input_node):
349+
if input_node.op == "placeholder":
350+
if input_node.meta["val"][0].is_contiguous():
351+
return
352+
elif self.is_nchw_node(input_node):
337353
return
338354

339355
if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -371,7 +387,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
371387
# first input to be nhwc. This makes this node's output nhwc too
372388
# Currently, all nodes like this should have all of their other
373389
# inputs as nchw, so fail if this is not true
374-
self.input_to_nhwc(graph_module, node.args[0], node)
390+
if node.name == "output":
391+
self.input_to_nhwc(graph_module, node.args[0][0], node)
392+
else:
393+
self.input_to_nhwc(graph_module, node.args[0], node)
394+
375395
for input_node in node.all_input_nodes[1:]:
376396
if self.is_nhwc_node(input_node):
377397
raise AssertionError(

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,24 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
106106
err == Error::Ok,
107107
Internal,
108108
"Failed to retrieve dim order from tensor!");
109-
ET_CHECK_OR_RETURN_ERROR(
110-
is_contiguous_dim_order(dim_order, tensor->dim()),
111-
Internal,
112-
"Expecting default dim_order but got a non default dim_order tensor for external input %u",
113-
i);
114109
size_t dims[XNN_MAX_TENSOR_DIMS];
115110
ET_CHECK_OR_RETURN_ERROR(
116111
num_dims <= XNN_MAX_TENSOR_DIMS,
117112
InvalidArgument,
118113
"XNNPACK backend accepts tensors with at most %d dims, but got %zu",
119114
XNN_MAX_TENSOR_DIMS,
120115
num_dims);
121-
for (int d = 0; d < num_dims; ++d) {
122-
dims[d] = tensor->size(d);
116+
117+
bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dims);
118+
if (is_channels_last) {
119+
dims[0] = tensor->size(0);
120+
dims[1] = tensor->size(2);
121+
dims[2] = tensor->size(3);
122+
dims[3] = tensor->size(1);
123+
} else {
124+
for (int d = 0; d < num_dims; ++d) {
125+
dims[d] = tensor->size(d);
126+
}
123127
}
124128
status =
125129
xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +224,24 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
220224

221225
// Convert new output shape into SizesType
222226
SizesType expected_output_size[kTensorDimensionLimit];
223-
for (size_t d = 0; d < num_dim; ++d) {
224-
expected_output_size[d] = static_cast<SizesType>(dims[d]);
227+
executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
228+
Error errr =
229+
ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
230+
ET_CHECK_OR_RETURN_ERROR(
231+
errr == Error::Ok,
232+
Internal,
233+
"Failed to retrieve dim order from tensor!");
234+
235+
bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dim);
236+
if (is_channels_last) {
237+
expected_output_size[0] = static_cast<SizesType>(dims[0]);
238+
expected_output_size[1] = static_cast<SizesType>(dims[3]);
239+
expected_output_size[2] = static_cast<SizesType>(dims[1]);
240+
expected_output_size[3] = static_cast<SizesType>(dims[2]);
241+
} else {
242+
for (size_t d = 0; d < num_dim; ++d) {
243+
expected_output_size[d] = static_cast<SizesType>(dims[d]);
244+
}
225245
}
226246

227247
executorch::aten::ArrayRef<SizesType> output_size{

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88

99
import torch
10+
from backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
1011
from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
1112
ChannelsLastTaggedReshapePass,
1213
)
@@ -58,6 +59,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
5859
.run_method_and_compare_outputs()
5960
)
6061

62+
class LinearConv(torch.nn.Module):
63+
def __init__(self):
64+
super().__init__()
65+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
66+
self.linear1 = torch.nn.Linear(4, 3)
67+
68+
def forward(self, x):
69+
y = self.linear1(x)
70+
return self.conv1(y)
71+
72+
def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
73+
tester = Tester(
74+
self.LinearConv().eval(),
75+
(torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
76+
)
77+
78+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
79+
80+
def test_conv_linear_dim_order_swaps_on_nchw_input(self):
81+
tester = Tester(
82+
self.LinearConv().eval(),
83+
(torch.randn(1, 3, 6, 4),),
84+
)
85+
86+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
87+
88+
class ConvLinearConv(torch.nn.Module):
89+
def __init__(self):
90+
super().__init__()
91+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
92+
self.linear1 = torch.nn.Linear(4, 4)
93+
94+
def forward(self, x):
95+
y = self.conv1(x)
96+
return self.linear1(y)
97+
98+
def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
99+
tester = Tester(
100+
self.ConvLinearConv().eval(),
101+
(torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
102+
)
103+
104+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
105+
106+
def test_linear_conv_dim_order_swaps_on_nchw_input(self):
107+
tester = Tester(
108+
self.ConvLinearConv().eval(),
109+
(torch.randn(1, 3, 6, 6),),
110+
)
111+
112+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
113+
114+
class Bilinear(torch.nn.Module):
115+
def __init__(self):
116+
super().__init__()
117+
118+
def forward(self, x):
119+
return torch.nn.functional.interpolate(
120+
x, scale_factor=2, mode="bilinear", align_corners=True
121+
)
122+
123+
def test_nhwc_input_on_nhwc_op(self):
124+
tester = Tester(
125+
self.Bilinear().eval(),
126+
(
127+
torch.arange(8)
128+
.reshape(1, 2, 2, 2)
129+
.to(torch.float32)
130+
.to(memory_format=torch.channels_last),
131+
),
132+
)
133+
134+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
135+
136+
def test_nchw_input_on_nhwc_op(self):
137+
tester = Tester(
138+
self.Bilinear().eval(),
139+
(torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
140+
)
141+
142+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
143+
61144
def test_qs8_channels_last_tagged_reshape_pass(self):
62145
for module, num_reshape in self.modules.items():
63146
(

backends/xnnpack/test/tester/tester.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from executorch.exir.backend.backend_api import validation_disabled
3333
from executorch.exir.backend.partitioner import Partitioner
34+
from executorch.exir.dim_order_utils import get_memory_format
3435
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
3536

3637
from executorch.exir.print_program import pretty_print, print_program
@@ -533,10 +534,13 @@ def fn(x):
533534
# create random tensor inputs with the shapes given above:
534535
random_inputs = []
535536
for arg_idx in range(len(self.example_inputs)):
537+
memFormat = get_memory_format(
538+
list(self.example_inputs[arg_idx].dim_order())
539+
)
536540
random_inputs.append(
537-
torch.randn(input_shapes[arg_idx]).to(
538-
dtype=self.example_inputs[arg_idx].dtype
539-
)
541+
torch.randn(input_shapes[arg_idx])
542+
.to(dtype=self.example_inputs[arg_idx].dtype)
543+
.to(memory_format=memFormat)
540544
)
541545

542546
yield tuple(random_inputs)

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,6 @@ def preprocess(
145145

146146
node_to_external_map = generate_node_to_external_map(ep, graph_module)
147147

148-
# Make sure all inputs are contiguous_format or NCHW or default dim order
149-
assert_default_dim_order(graph_module)
150-
151148
# TODO retrace the graph module to lift the new params may have
152149
# been added to the graph in passes
153150

cmake_wrapper.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
3+
source /Users/madragna/executorch/.venv/bin/activate
4+
cmake "$@"

0 commit comments

Comments
 (0)