30 changes: 24 additions & 6 deletions backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
@@ -56,9 +56,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):

# Set of ops that require memory format to be NCHW
memory_sensitive_ops_nchw = {
"output",
exir_ops.edge.aten.squeeze_copy.dim,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.linear.default,
}

# Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +91,18 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
return not self.is_nhwc_node(node)

def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nhwc
return (
node.target in self.memory_sensitive_ops_nhwc
or node.name == "output"
and not node.args[0][0].meta["val"].is_contiguous()
)

def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
return node.target in self.memory_sensitive_ops_nchw
return (
node.target in self.memory_sensitive_ops_nchw
or node.name == "output"
and node.args[0][0].meta["val"].is_contiguous()
)

def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
# There are two conditions that must be met for a node to be able to
@@ -269,7 +277,10 @@ def input_to_nhwc(
# serializing graph, but don't do anything else here
self.mark_as_nhwc_node(input_node)

if self.is_nhwc_node(input_node):
if input_node.op == "placeholder":
if not input_node.meta["val"][0].is_contiguous():
return
elif self.is_nhwc_node(input_node):
return

if not self.can_be_converted_to_nhwc(input_node):
@@ -333,7 +344,10 @@ def input_to_nchw(
# do anything else here
self.mark_as_nchw_node(input_node)

if self.is_nchw_node(input_node):
if input_node.op == "placeholder":
if input_node.meta["val"].is_contiguous():
return
elif self.is_nchw_node(input_node):
return

if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -371,7 +385,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
# first input to be nhwc. This makes this node's output nhwc too
# Currently, all nodes like this should have all of their other
# inputs as nchw, so fail if this is not true
self.input_to_nhwc(graph_module, node.args[0], node)
if node.name == "output":
self.input_to_nhwc(graph_module, node.args[0][0], node)
else:
self.input_to_nhwc(graph_module, node.args[0], node)

for input_node in node.all_input_nodes[1:]:
if self.is_nhwc_node(input_node):
raise AssertionError(
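The decisive change above is the output-node handling: because Python's `and` binds tighter than `or`, the new `requires_nhwc_input` reads as "the target is in the NHWC set, or this is the output node and its example value is channels_last". A minimal sketch of that rule, using a hypothetical helper `output_wants_nhwc` rather than the pass itself:

```python
# Minimal sketch (not the pass itself): the graph output keeps whatever memory
# format its example value already has, so a conversion is only inserted when
# the producing node's format disagrees with it.
import torch

def output_wants_nhwc(example_output: torch.Tensor) -> bool:
    # Mirrors `node.name == "output" and not node.args[0][0].meta["val"].is_contiguous()`
    return not example_output.is_contiguous()

nchw_out = torch.randn(1, 3, 8, 8)
nhwc_out = nchw_out.to(memory_format=torch.channels_last)
print(output_wants_nhwc(nchw_out))  # False -> output is kept NCHW
print(output_wants_nhwc(nhwc_out))  # True  -> output is kept NHWC
```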
23 changes: 14 additions & 9 deletions backends/xnnpack/runtime/XNNExecutor.cpp
@@ -106,20 +106,16 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
err == Error::Ok,
Internal,
"Failed to retrieve dim order from tensor!");
ET_CHECK_OR_RETURN_ERROR(
is_contiguous_dim_order(dim_order, tensor->dim()),
Internal,
"Expecting default dim_order but got a non default dim_order tensor for external input %u",
i);
size_t dims[XNN_MAX_TENSOR_DIMS];
ET_CHECK_OR_RETURN_ERROR(
num_dims <= XNN_MAX_TENSOR_DIMS,
InvalidArgument,
"XNNPACK backend accepts tensors with at most %d dims, but got %zu",
XNN_MAX_TENSOR_DIMS,
num_dims);
for (int d = 0; d < num_dims; ++d) {
dims[d] = tensor->size(d);

for (int j = 0; j < num_dims; ++j) {
dims[j] = tensor->size(static_cast<int>(dim_order[j]));
}
status =
xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
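In `prepare_args`, the old code rejected non-default dim orders outright; the new code instead feeds XNNPACK the sizes read in the tensor's dim order, i.e. in physical memory order. A small worked example of that remapping (plain PyTorch, not the executor):

```python
# For a channels_last tensor, dim_order() is (0, 2, 3, 1), so the sizes handed
# to xnn_reshape_external_value come out in NHWC order.
import torch

t = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
dim_order = t.dim_order()              # (0, 2, 3, 1)
dims = [t.size(d) for d in dim_order]  # sizes in physical (memory) order
print(list(dim_order), dims)           # [0, 2, 3, 1] [1, 8, 8, 3]
```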
@@ -220,8 +216,17 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {

// Convert new output shape into SizesType
SizesType expected_output_size[kTensorDimensionLimit];
for (size_t d = 0; d < num_dim; ++d) {
expected_output_size[d] = static_cast<SizesType>(dims[d]);
executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
Error errr =
ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
ET_CHECK_OR_RETURN_ERROR(
errr == Error::Ok,
Internal,
"Failed to retrieve dim order from tensor!");

for (int j = 0; j < num_dim; ++j) {
expected_output_size[static_cast<int>(dim_order[j])] =
static_cast<SizesType>(dims[j]);
}

executorch::aten::ArrayRef<SizesType> output_size{
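`resize_outputs` needs the inverse mapping: XNNPACK reports the new sizes in physical order, and `dim_order[j]` names the logical dimension stored at physical position `j`, so scattering recovers the logical (NCHW) sizes. A sketch with made-up numbers:

```python
# Hypothetical channels_last output: physical sizes from XNNPACK are NHWC;
# scattering through dim_order restores the logical NCHW sizes.
dim_order = (0, 2, 3, 1)
dims = [1, 8, 8, 16]                      # sizes reported by XNNPACK
expected_output_size = [0] * len(dims)
for j, d in enumerate(dim_order):
    expected_output_size[d] = dims[j]
print(expected_output_size)               # [1, 16, 8, 8]
```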
115 changes: 115 additions & 0 deletions backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
@@ -43,6 +43,121 @@ def setUp(self):
)
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"

def run_tester(self, module, inputs):
tester = Tester(
module.eval(),
inputs,
)
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()

class LinearConv(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.linear1 = torch.nn.Linear(4, 3)

def forward(self, x):
y = self.linear1(x)
return self.conv1(y)

LinearConvModule = LinearConv()

class ConvLinearConv(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.linear1 = torch.nn.Linear(4, 4)

def forward(self, x):
y = self.conv1(x)
return self.linear1(y)

ConvLinearConvModule = ConvLinearConv()

class Bilinear(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.nn.functional.interpolate(
x, scale_factor=2, mode="bilinear", align_corners=True
)

BilinearModule = Bilinear()

class TwoConvAdd(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = torch.nn.Conv2d(5, 16, 3, padding=1)

def forward(self, x1, x2):
y1 = self.conv1(x1)
y2 = self.conv2(x2)
return torch.add(y1, y2)

TwoConvAddModule = TwoConvAdd()

def test_two_conv_add(self):
x1 = torch.randn(1, 3, 8, 8)
x2 = torch.randn(1, 5, 8, 8)

# Test with regular format inputs
self.run_tester(self.TwoConvAddModule, (x1, x2))

# Test with channels_last format inputs
x1_cl = x1.to(memory_format=torch.channels_last)
x2_cl = x2.to(memory_format=torch.channels_last)
self.run_tester(self.TwoConvAddModule, (x1_cl, x2_cl))

# Test with mixed format inputs
self.run_tester(self.TwoConvAddModule, (x1_cl, x2))
self.run_tester(self.TwoConvAddModule, (x1, x2_cl))

# Verify the pass adds the expected number of to_copy operations
(
Tester(self.TwoConvAddModule, (x1, x2))
.export()
.to_edge()
.run_passes(self.PassStage)
.check_count(
{
self.to_copy_name: 3, # 2 for inputs to conv, 1 for outputs from add
}
)
.run_method_and_compare_outputs()
)

def test_conv_linear_dim_order_swaps(self):
self.run_tester(self.LinearConvModule, (torch.randn(1, 3, 6, 4),))
self.run_tester(
self.LinearConvModule,
(torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
)

def test_linear_conv_dim_order_swaps(self):
self.run_tester(self.ConvLinearConvModule, (torch.randn(1, 3, 6, 6),))
self.run_tester(
self.ConvLinearConvModule,
(torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
)

def test_nhwc_nchw_input_on_nhwc_op(self):
self.run_tester(
self.BilinearModule,
(
torch.arange(8)
.reshape(1, 2, 2, 2)
.to(torch.float32)
.to(memory_format=torch.channels_last),
),
)

self.run_tester(
self.BilinearModule,
(torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
)

def test_fp32_channels_last_tagged_reshape_pass(self):
for module, num_reshape in self.modules.items():
(
10 changes: 7 additions & 3 deletions backends/xnnpack/test/tester/tester.py
@@ -31,6 +31,7 @@
)
from executorch.exir.backend.backend_api import validation_disabled
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.dim_order_utils import get_memory_format
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.exir.print_program import pretty_print, print_program
@@ -533,10 +534,13 @@ def fn(x):
# create random tensor inputs with the shapes given above:
random_inputs = []
for arg_idx in range(len(self.example_inputs)):
memFormat = get_memory_format(
list(self.example_inputs[arg_idx].dim_order())
)
random_inputs.append(
torch.randn(input_shapes[arg_idx]).to(
dtype=self.example_inputs[arg_idx].dtype
)
torch.randn(input_shapes[arg_idx])
.to(dtype=self.example_inputs[arg_idx].dtype)
.to(memory_format=memFormat)
)

yield tuple(random_inputs)
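The dynamic-shape tester now regenerates random inputs in the same memory format as the example inputs, by mapping each input's dim order back to a `torch.memory_format`. A quick round-trip check, assuming `get_memory_format` behaves as it is used above:

```python
import torch
from executorch.exir.dim_order_utils import get_memory_format

x = torch.randn(1, 3, 8, 8).to(memory_format=torch.channels_last)
fmt = get_memory_format(list(x.dim_order()))   # expected: torch.channels_last
y = torch.randn(x.shape).to(dtype=x.dtype).to(memory_format=fmt)
print(y.is_contiguous(memory_format=torch.channels_last))  # True
```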
3 changes: 0 additions & 3 deletions backends/xnnpack/xnnpack_preprocess.py
@@ -145,9 +145,6 @@ def preprocess(

node_to_external_map = generate_node_to_external_map(ep, graph_module)

# Make sure all inputs are contiguous_format or NCHW or default dim order
assert_default_dim_order(graph_module)

# TODO retrace the graph module to lift the new params may have
# been added to the graph in passes
