Commit 74a97cd

partition to_dim_order_copy in XNN delegate

1 parent be56146

5 files changed, 117 insertions(+), 0 deletions(-)


backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 5 additions & 0 deletions
@@ -400,6 +400,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                if node.meta["val"].is_contiguous():
+                    self.mark_as_nchw_node(node)
+                else:
+                    self.mark_as_nhwc_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)

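Note on the new branch: it classifies a _to_copy node by the memory format of its output. node.meta["val"] holds the fake tensor produced during tracing, and is_contiguous() is True only for the default (NCHW) dim order. A minimal sketch of that predicate on eager tensors (illustrative only, not part of the commit):

import torch

x = torch.randn(1, 3, 6, 6)                  # default contiguous layout (NCHW)
y = x.to(memory_format=torch.channels_last)  # NHWC strides, same logical shape

print(x.is_contiguous())  # True  -> such an output gets tagged as an nchw node
print(y.is_contiguous())  # False -> tagged as an nhwc node instead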
backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -97,6 +98,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    ToDimOrderCopyConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,

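These two hunks only register the new config: the first imports ToDimOrderCopyConfig alongside the other generic node configs, and the second adds it to the list from which the partitioner's set of supported configs is presumably assembled. Without this registration, the config defined below would never be consulted during partitioning.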
backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
@@ -364,6 +364,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
+    target_name = "_to_dim_order_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
 

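With the config registered, a standalone memory-format cast can be claimed by an XNNPACK partition instead of forcing a graph break around it. A minimal end-to-end sketch, assuming the standard ExecuTorch lowering entry points (the partitioner import matches the one used by the new test file; the module and names here are illustrative):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class CastThenConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        # x.to(memory_format=...) becomes _to_dim_order_copy in the edge dialect
        return self.conv(x.to(memory_format=torch.channels_last))

ep = torch.export.export(CastThenConv().eval(), (torch.randn(1, 3, 6, 6),))
# With ToDimOrderCopyConfig in place, the copy node can be delegated
# together with the conv rather than left to the portable runtime.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])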
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import Optional, Tuple
+
+import torch
+from executorch.backends.xnnpack.partition.config.xnnpack_config import (
+    ConfigPrecisionType,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+)
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
+    QuantizationConfig,
+)
+from executorch.backends.xnnpack.test.tester import Quantize, Tester
+from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelLastBeforeLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            return self.linear(y)
+
+    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
+    def test_channel_last_before_linear(self):
+        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
+
+
+    class ContiguousBeforeConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ContiguousBeforeConvModule = ContiguousBeforeConv()
+    def test_contiguous_before_conv(self):
+        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
+
+    class DtypeAndMemoryFormatConversion(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.conv(y)
+
+    DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()
+    def test_dtype_and_memory_format_conversion(self):
+        self.run_tester(self.DtypeAndMemoryFormatConversionModule, (torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),))
+
+    class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.linear(y)
+
+    DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()
+    def test_dtype_and_memory_format_with_linear(self):
+        self.run_tester(self.DtypeAndMemoryFormatWithLinearModule, (torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),))

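The last two module tests above exercise a dtype change and a memory-format change in a single .to() call, which should lower to a single copy node carrying both conversions. At the eager-tensor level the combined cast behaves as follows (illustrative only):

import torch

x = torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32)
y = x.to(torch.float, memory_format=torch.channels_last)

print(y.dtype)  # torch.float32
print(y.is_contiguous(memory_format=torch.channels_last))  # True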
backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 17 additions & 0 deletions
@@ -173,6 +173,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
+    class LinearConvDimSwap(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    LinearConvDimSwapModule = LinearConvDimSwap()
+
+    def test_conv_linear_dim_order_swap_partitioner(self):
+        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (

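To run the new partitioner test locally, the usual unittest invocation should work from an ExecuTorch checkout (module path taken from the file header above):

python -m unittest backends.xnnpack.test.passes.test_channels_last_tagged_reshape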