Skip to content

Commit d195f85

Browse files
committed
partition to_dim_order_copy in XNN delegate
1 parent 18e4240 commit d195f85

File tree

5 files changed

+116
-0
lines changed

5 files changed

+116
-0
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
395395
# The node requires nchw inputs
396396
for input_node in node.all_input_nodes:
397397
self.input_to_nchw(graph_module, input_node, node)
398+
elif node.target == exir_ops.edge.aten._to_copy.default:
399+
if node.meta["val"].is_contiguous():
400+
self.mark_as_nchw_node(node)
401+
else:
402+
self.mark_as_nhwc_node(node)
398403
else:
399404
# The node can have inputs in any format (but all must be the
400405
# same format)

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SoftmaxConfig,
5050
SquareRootConfig,
5151
SubConfig,
52+
ToDimOrderCopyConfig,
5253
UpsampleBilinear2dConfig,
5354
)
5455
from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -99,6 +100,7 @@
99100
PreluConfig,
100101
ReciprocalSquareRootConfig,
101102
ReLUConfig,
103+
ToDimOrderCopyConfig,
102104
# SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
103105
SigmoidConfig,
104106
SliceCopyConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
371371
return [ConfigPrecisionType.FP32]
372372

373373

374+
class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
375+
target_name = "_to_dim_order_copy.default"
376+
377+
def supported_precision_types(self) -> List[ConfigPrecisionType]:
378+
return [ConfigPrecisionType.FP32]
379+
380+
374381
class MeanDimConfig(GenericNodePartitionerConfig):
375382
target_name = "mean.dim"
376383

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.backends.xnnpack.test.tester import Tester
12+
13+
14+
class TestChannelsLastTaggedReshapePass(unittest.TestCase):
15+
def setUp(self):
16+
torch._dynamo.reset()
17+
18+
def run_tester(self, module, inputs):
19+
tester = Tester(
20+
module.eval(),
21+
inputs,
22+
)
23+
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
24+
25+
class ChannelLastBeforeLinear(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.linear = torch.nn.Linear(3, 3)
29+
30+
def forward(self, x):
31+
y = x.to(memory_format=torch.channels_last)
32+
return self.linear(y)
33+
34+
ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
35+
36+
def test_channel_last_before_linear(self):
37+
self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
38+
39+
class ContiguousBeforeConv(torch.nn.Module):
40+
def __init__(self):
41+
super().__init__()
42+
self.conv = torch.nn.Conv2d(3, 3, 3)
43+
44+
def forward(self, x):
45+
y = x.to(memory_format=torch.contiguous_format)
46+
return self.conv(y)
47+
48+
ContiguousBeforeConvModule = ContiguousBeforeConv()
49+
50+
def test_contiguous_before_conv(self):
51+
self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
52+
53+
class DtypeAndMemoryFormatConversion(torch.nn.Module):
54+
def __init__(self):
55+
super().__init__()
56+
self.conv = torch.nn.Conv2d(3, 3, 3)
57+
58+
def forward(self, x):
59+
y = x.to(torch.float, memory_format=torch.channels_last)
60+
return self.conv(y)
61+
62+
DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()
63+
64+
def test_dtype_and_memory_format_conversion(self):
65+
self.run_tester(
66+
self.DtypeAndMemoryFormatConversionModule,
67+
(torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
68+
)
69+
70+
class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
71+
def __init__(self):
72+
super().__init__()
73+
self.linear = torch.nn.Linear(3, 3)
74+
75+
def forward(self, x):
76+
y = x.to(torch.float, memory_format=torch.channels_last)
77+
return self.linear(y)
78+
79+
DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()
80+
81+
def test_dtype_and_memory_format_with_linear(self):
82+
self.run_tester(
83+
self.DtypeAndMemoryFormatWithLinearModule,
84+
(torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
85+
)

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
173173
.run_method_and_compare_outputs()
174174
)
175175

176+
class LinearConvDimSwap(torch.nn.Module):
177+
def __init__(self):
178+
super().__init__()
179+
self.conv1 = torch.nn.Conv2d(3, 3, 3)
180+
self.linear1 = torch.nn.Linear(4, 3)
181+
182+
def forward(self, x):
183+
y = self.linear1(x)
184+
y = y.to(memory_format=torch.channels_last)
185+
y = y.to(memory_format=torch.contiguous_format)
186+
return self.conv1(y)
187+
188+
LinearConvDimSwapModule = LinearConvDimSwap()
189+
190+
def test_conv_linear_dim_order_swap_partitioner(self):
191+
self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
192+
176193
def test_qs8_channels_last_tagged_reshape_pass(self):
177194
for module, num_reshape in self.modules.items():
178195
(

0 commit comments

Comments
 (0)