2 changes: 2 additions & 0 deletions backends/xnnpack/partition/config/__init__.py
@@ -50,6 +50,7 @@
SquareRootConfig,
SubConfig,
TanhConfig,
ToDimOrderCopyConfig,
UpsampleBilinear2dConfig,
)
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -102,6 +103,7 @@
ReciprocalSquareRootConfig,
ReLUConfig,
TanhConfig,
ToDimOrderCopyConfig,
SigmoidConfig,
SliceCopyConfig,
SoftmaxConfig,
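Registering ToDimOrderCopyConfig in these two lists is what makes the XNNPACK partitioner consider _to_dim_order_copy nodes at all. As a rough end-to-end sketch of the flow this enables (the DimOrderOnly module is illustrative; the import paths and call chain follow the standard ExecuTorch export flow rather than anything specific to this PR):

import torch

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class DimOrderOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        # Pure memory-format change: dtype is unchanged, so the newly
        # registered ToDimOrderCopyConfig can admit the copy node.
        return self.conv(x.to(memory_format=torch.channels_last))


ep = torch.export.export(DimOrderOnly().eval(), (torch.randn(1, 3, 6, 6),))
lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
print(lowered.exported_program().graph_module.code)

With the new config in place, the dim-order copy should land inside the XNNPACK delegate instead of remaining in the edge graph, provided the dtype is unchanged.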
29 changes: 29 additions & 0 deletions backends/xnnpack/partition/config/generic_node_configs.py
@@ -425,6 +425,35 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32]


class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
target_name = "_to_dim_order_copy.default"

def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
"""
Only partition dim order conversions; dtype conversions are not supported.
"""
if not self.check_common_constraints(node, ep):
return False

# Get input node and compare dtypes
input_node = get_input_node(node, 0)
input_dtype = input_node.meta["val"].dtype
output_dtype = node.meta["val"].dtype

# Return False if doing dtype conversion
if input_dtype != output_dtype:
why(
node,
reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
)
return False

return True

def supported_precision_types(self) -> List[ConfigPrecisionType]:
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]


class MeanDimConfig(GenericNodePartitionerConfig):
target_name = "mean.dim"

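The constraint above admits a _to_dim_order_copy node only when its input and output dtypes match, i.e. a pure dim-order (memory-format) change; dtype-changing copies are rejected with a reason. A minimal standalone sketch of the same predicate, assuming a torch.fx node whose meta["val"] carries a FakeTensor (the helper name is_pure_dim_order_copy is hypothetical):

import torch


def is_pure_dim_order_copy(node: torch.fx.Node) -> bool:
    # Hypothetical mirror of ToDimOrderCopyConfig.check_constraints:
    # the copy is partitionable only if it does not change dtype.
    input_node = node.args[0]
    assert isinstance(input_node, torch.fx.Node)
    return input_node.meta["val"].dtype == node.meta["val"].dtype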
113 changes: 113 additions & 0 deletions backends/xnnpack/test/ops/test_to_copy.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor comment: As I said earlier, using to_copy is OK but we can just as easily move to to_dim_order_copy and remove the dim_order ops revert pass.

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester


class TestToCopy(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()

def run_tester(self, module, inputs):
tester = Tester(
module.eval(),
inputs,
)
tester.export().to_edge_transform_and_lower().check_not(
["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
).to_executorch().serialize().run_method_and_compare_outputs()

class ChannelLastBeforeLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
y = x.to(memory_format=torch.channels_last)
return self.linear(y)

ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()

def test_channel_last_before_linear(self):
self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))

class ContiguousBeforeConv(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
y = x.to(memory_format=torch.contiguous_format)
return self.conv(y)

ContiguousBeforeConvModule = ContiguousBeforeConv()

def test_contiguous_before_conv(self):
self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))

class DtypeAndMemoryFormatConversion(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
y = x.to(torch.float, memory_format=torch.channels_last)
return self.conv(y)

DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()

def test_dtype_and_memory_format_conversion(self):
self.run_tester(
self.DtypeAndMemoryFormatConversionModule,
(torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
)

class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)

def forward(self, x):
y = x.to(torch.float, memory_format=torch.channels_last)
return self.linear(y)

DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()

def test_dtype_and_memory_format_with_linear(self):
self.run_tester(
self.DtypeAndMemoryFormatWithLinearModule,
(torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
)

class QuantizedToCopy(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
y = self.conv(x)
y = y.to(memory_format=torch.contiguous_format)
return self.conv2(y)

QuantizedToCopyModule = QuantizedToCopy()

def test_quantized_to_copy(self):
tester = Tester(
self.QuantizedToCopyModule.eval(),
(torch.randn(1, 3, 9, 9),),
)

tester.quantize().export().to_edge_transform_and_lower().check_not(
[
"executorch_exir_dialects_edge__ops_aten__to_copy_default",
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
]
).to_executorch().serialize().run_method_and_compare_outputs(qtol=0.01)
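The memory-format round trips these tests exercise map directly onto dim orders. A quick illustration, assuming a PyTorch build that exposes Tensor.dim_order():

import torch

x = torch.randn(1, 3, 6, 6)
print(x.dim_order())                                        # (0, 1, 2, 3) - contiguous
print(x.to(memory_format=torch.channels_last).dim_order())  # (0, 2, 3, 1) - channels_last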
@@ -54,7 +54,9 @@ def run_tester(self, module, inputs):
module.eval(),
inputs,
)
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
tester.export().to_edge_transform_and_lower().check_not(
["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
).to_executorch().serialize().run_method_and_compare_outputs()

class LinearConv(torch.nn.Module):
def __init__(self):
@@ -179,6 +181,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
.run_method_and_compare_outputs()
)

class LinearConvDimSwap(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.linear1 = torch.nn.Linear(4, 3)

def forward(self, x):
y = self.linear1(x)
y = y.to(memory_format=torch.channels_last)
y = y.to(memory_format=torch.contiguous_format)
return self.conv1(y)

LinearConvDimSwapModule = LinearConvDimSwap()

def test_conv_linear_dim_order_swap_partitioner(self):
self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))

def test_qs8_channels_last_tagged_reshape_pass(self):
for module, num_reshape in self.modules.items():
(
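On the LinearConvDimSwap pattern above: converting to channels_last and straight back to contiguous is value-preserving, so both copies qualify as pure dim-order conversions for the new config. A small sanity check, illustrative only:

import torch

x = torch.randn(1, 3, 6, 4)
y = x.to(memory_format=torch.channels_last).to(memory_format=torch.contiguous_format)
assert torch.equal(x, y)  # only the physical layout changed along the way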