26 changes: 26 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,26 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Debug CMake project",
"type": "lldb", // https://github.com/vadimcn/vscode-lldb
"request": "launch",
"program": "${command:cmake.launchTargetPath}",
"args": [
"--model_path=./add.pte",
]
},
{
"name": "Debug python proj",
"type": "debugpy",
"request": "launch",
"module": "unittest",
"args": [
"./backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py"
]
}
]
}
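The lldb configuration passes `--model_path=./add.pte` to the launched target. A minimal sketch of how such a file could be produced, assuming the executorch `to_edge`/`to_executorch` export APIs (the module and shapes here are illustrative, not part of this PR):

# Illustrative: export a trivial Add module to add.pte.
import torch
from torch.export import export
from executorch.exir import to_edge


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


ep = export(Add(), (torch.ones(1), torch.ones(1)))
with open("add.pte", "wb") as f:
    f.write(to_edge(ep).to_executorch().buffer)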
83 changes: 83 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,83 @@
{
"files.associations": {
"cstdlib": "cpp",
"__bit_reference": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__node_handle": "cpp",
"__split_buffer": "cpp",
"__tree": "cpp",
"__verbose_abort": "cpp",
"array": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"charconv": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"execution": "cpp",
"memory": "cpp",
"forward_list": "cpp",
"future": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"list": "cpp",
"locale": "cpp",
"map": "cpp",
"mutex": "cpp",
"new": "cpp",
"optional": "cpp",
"print": "cpp",
"queue": "cpp",
"ratio": "cpp",
"regex": "cpp",
"set": "cpp",
"shared_mutex": "cpp",
"sstream": "cpp",
"stack": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"typeindex": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"variant": "cpp",
"vector": "cpp",
"algorithm": "cpp",
"iterator": "cpp",
"tuple": "cpp",
"span": "cpp",
"*.inc": "cpp",
"alignedvector3": "cpp"
},
"C_Cpp.default.compilerPath": "/library/developer/commandlinetools/usr/bin/c++",
"python.analysis.typeCheckingMode": "off",
"python.testing.unittestArgs": [
"-v",
"-s",
"./backends",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": [
"."
]
}
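Note that these settings enable pytest and disable the built-in unittest runner, while `unittestArgs` still describes the discovery scope. A minimal runnable sketch of that same discovery (illustrative only; run from the repo root):

# Discovery equivalent to the unittestArgs above
# (pytest is the runner these settings actually enable).
import unittest

suite = unittest.defaultTestLoader.discover(start_dir="./backends", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)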
51 changes: 45 additions & 6 deletions backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
@@ -56,9 +59,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):

# Set of ops that require memory format to be NCHW
memory_sensitive_ops_nchw = {
"output",
exir_ops.edge.aten.squeeze_copy.dim,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.linear.default,
}

# Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +94,20 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
return not self.is_nhwc_node(node)

def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nhwc
return (
node.target in self.memory_sensitive_ops_nhwc
or (node.name == "output" and not node.args[0][0].meta["val"].is_contiguous())
)

def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nchw
return (
node.target in self.memory_sensitive_ops_nchw
or (
node.name == "output"
# The runtime output must match the layout recorded at trace time,
# so a contiguous traced output forces this node's inputs to NCHW.
and node.args[0][0].meta["val"].is_contiguous()
)
)
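# Illustrative only (plain PyTorch, ad hoc example): the is_contiguous()
# predicate the two checks above branch on distinguishes the layouts.
# A channels-last tensor reports False for the default contiguous format:
#
#     nchw = torch.empty(1, 3, 8, 8)  # default NCHW layout
#     nhwc = nchw.to(memory_format=torch.channels_last)
#     nchw.is_contiguous()  # True  -> traced output stays NCHW
#     nhwc.is_contiguous()  # False -> traced output is channels-last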

def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
# There are two conditions that must be met for a node to be able to
@@ -269,8 +282,17 @@ def input_to_nhwc(
# serializing graph, but don't do anything else here
self.mark_as_nhwc_node(input_node)

if self.is_nhwc_node(input_node):
if input_node.op == "placeholder":
if not input_node.meta["val"][0].is_contiguous():
return
elif self.is_nhwc_node(input_node):
return

if not self.can_be_converted_to_nhwc(input_node):
raise AssertionError(
@@ -333,8 +355,21 @@ def input_to_nchw(
# do anything else here
self.mark_as_nchw_node(input_node)

if self.is_nchw_node(input_node):
if input_node.op == "placeholder":
if input_node.meta["val"][0].is_contiguous():
return
elif self.is_nchw_node(input_node):
return
# TODO: meta["val"] is traced before passes run and is only regenerated
# after this pass, so eager mode assumes conv inputs/outputs keep the same
# memory format (linear lowers to an implicitly NCHW conv). A bare
# is_nchw_node() check mis-fired here because placeholders are tagged NCHW
# by default, hence the explicit placeholder branch above.
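# Illustrative sketch (ad hoc toy model, not part of this PR) of how the
# fake tensors recorded in node.meta["val"] expose placeholder layout to a
# pass like this one:
#
#     ep = torch.export.export(
#         torch.nn.Conv2d(3, 8, 3),
#         (torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last),),
#     )
#     for n in ep.graph_module.graph.nodes:
#         if n.op == "placeholder" and isinstance(n.meta["val"], torch.Tensor):
#             print(n.name, n.meta["val"].is_contiguous())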

if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
# Already has an associated NCHW node
@@ -371,7 +406,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
# first input to be nhwc. This makes this node's output nhwc too
# Currently, all nodes like this should have all of their other
# inputs as nchw, so fail if this is not true
self.input_to_nhwc(graph_module, node.args[0], node)
if node.name == "output":
self.input_to_nhwc(graph_module, node.args[0][0], node)
else:
self.input_to_nhwc(graph_module, node.args[0], node)

for input_node in node.all_input_nodes[1:]:
if self.is_nhwc_node(input_node):
raise AssertionError(
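The output-node special case above works because an fx graph's `output` node packs all return values into `args[0]`, so `args[0][0]` is the first returned value. A minimal sketch in plain torch.fx (illustrative names):

# The fx "output" node stores the graph's return values as args[0].
import torch

def f(x):
    return (x + 1,)

gm = torch.fx.symbolic_trace(f)
out_node = next(n for n in gm.graph.nodes if n.op == "output")
print(out_node.args)        # ((add,),) -> args[0] is the tuple of outputs
print(out_node.args[0][0])  # add       -> first returned value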
38 changes: 29 additions & 9 deletions backends/xnnpack/runtime/XNNExecutor.cpp
@@ -106,20 +106,24 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
err == Error::Ok,
Internal,
"Failed to retrieve dim order from tensor!");
ET_CHECK_OR_RETURN_ERROR(
is_contiguous_dim_order(dim_order, tensor->dim()),
Internal,
"Expecting default dim_order but got a non default dim_order tensor for external input %u",
i);
size_t dims[XNN_MAX_TENSOR_DIMS];
ET_CHECK_OR_RETURN_ERROR(
num_dims <= XNN_MAX_TENSOR_DIMS,
InvalidArgument,
"XNNPACK backend accepts tensors with at most %d dims, but got %zu",
XNN_MAX_TENSOR_DIMS,
num_dims);
for (int d = 0; d < num_dims; ++d) {
dims[d] = tensor->size(d);

bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dims);
if (is_channels_last) {
// XNNPACK expects physical NHWC sizes, so permute the tensor's logical
// NCHW sizes to N, H, W, C.
dims[0] = tensor->size(0);
dims[1] = tensor->size(2);
dims[2] = tensor->size(3);
dims[3] = tensor->size(1);
} else {
for (int d = 0; d < num_dims; ++d) {
dims[d] = tensor->size(d);
}
}
status =
xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +224,24 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {

// Convert new output shape into SizesType
SizesType expected_output_size[kTensorDimensionLimit];
for (size_t d = 0; d < num_dim; ++d) {
expected_output_size[d] = static_cast<SizesType>(dims[d]);
executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
Error dim_order_err =
ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
ET_CHECK_OR_RETURN_ERROR(
dim_order_err == Error::Ok,
Internal,
"Failed to retrieve dim order from tensor!");

bool is_channels_last = executorch::runtime::is_channels_last_dim_order(dim_order, num_dim);
if (is_channels_last) {
// XNNPACK reports physical NHWC sizes; permute N, H, W, C back to the
// tensor's logical NCHW order before resizing.
expected_output_size[0] = static_cast<SizesType>(dims[0]);
expected_output_size[1] = static_cast<SizesType>(dims[3]);
expected_output_size[2] = static_cast<SizesType>(dims[1]);
expected_output_size[3] = static_cast<SizesType>(dims[2]);
} else {
for (size_t d = 0; d < num_dim; ++d) {
expected_output_size[d] = static_cast<SizesType>(dims[d]);
}
}

executorch::aten::ArrayRef<SizesType> output_size{
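The two size permutations above are inverses: prepare_args maps logical NCHW sizes to the NHWC sizes XNNPACK expects, and resize_outputs maps them back. A small sketch (illustrative Python, ad hoc names):

# prepare_args: dims[0..3] = size(0), size(2), size(3), size(1)
def nchw_to_nhwc(sizes):
    n, c, h, w = sizes
    return [n, h, w, c]

# resize_outputs: expected[0..3] = dims[0], dims[3], dims[1], dims[2]
def nhwc_to_nchw(dims):
    n, h, w, c = dims
    return [n, c, h, w]

sizes = [1, 3, 8, 8]
assert nhwc_to_nchw(nchw_to_nhwc(sizes)) == sizes  # round-trips exactly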