Commit c26c56b

partition to_dim_order_copy in XNN delegate
1 parent 0fa73fd commit c26c56b

File tree

5 files changed, +167 -1 lines

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 5 additions & 0 deletions
@@ -395,6 +395,11 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                if node.kwargs["memory_format"] == torch.channels_last:
+                    self.mark_as_nhwc_node(node)
+                else:
+                    self.mark_as_nchw_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
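
For context, the memory_format kwarg this new branch reads comes from ordinary Tensor.to() calls in the exported model. A minimal sketch (illustrative only, not code from this commit):

import torch

# x.to(memory_format=torch.channels_last) is exported as an
# aten._to_copy.default edge node whose kwargs carry the memory_format,
# so the pass above tags it as NHWC; contiguous_format falls into the
# else branch and is tagged NCHW.
x = torch.randn(1, 3, 8, 8)
nhwc = x.to(memory_format=torch.channels_last)
nchw = nhwc.to(memory_format=torch.contiguous_format)
assert nhwc.is_contiguous(memory_format=torch.channels_last)
assert nchw.is_contiguous()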

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,7 @@
     SquareRootConfig,
     SubConfig,
     TanhConfig,
+    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -102,6 +103,7 @@
     ReciprocalSquareRootConfig,
     ReLUConfig,
     TanhConfig,
+    ToDimOrderCopyConfig,
     SigmoidConfig,
     SliceCopyConfig,
     SoftmaxConfig,
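
Registering ToDimOrderCopyConfig in these lists is what lets the default XNNPACK partitioner consider dim-order copy nodes. A rough end-to-end sketch of how the config comes into play (module, shapes, and flow are illustrative assumptions, not taken from this commit):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class ToChannelsLast(torch.nn.Module):
    def forward(self, x):
        # A pure memory-format change; no dtype cast.
        return x.to(memory_format=torch.channels_last)


ep = torch.export.export(ToChannelsLast(), (torch.randn(1, 3, 8, 8),))
# With ToDimOrderCopyConfig registered, the resulting dim-order copy node
# becomes a candidate for the XNNPACK delegate during partitioning.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])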

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 29 additions & 0 deletions
@@ -397,6 +397,35 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
+    target_name = "_to_dim_order_copy.default"
+
+    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
+        """
+        Only support dim order conversion partitioning, not DType conversions
+        """
+        if not self.check_common_constraints(node, ep):
+            return False
+
+        # Get input node and compare dtypes
+        input_node = get_input_node(node, 0)
+        input_dtype = input_node.meta["val"].dtype
+        output_dtype = node.meta["val"].dtype
+
+        # Return False if doing dtype conversion
+        if input_dtype != output_dtype:
+            why(
+                node,
+                reason=f"dtype conversion from {input_dtype} to {output_dtype} is not supported",
+            )
+            return False
+
+        return True
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
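
To illustrate the constraint (a sketch, not code from this commit): a to() that only changes dim order satisfies check_constraints, while one that also casts dtype is rejected by this config:

import torch


class DimOrderOnly(torch.nn.Module):
    def forward(self, x):
        # Input and output dtypes match, so check_constraints above
        # returns True and the node can be partitioned to XNNPACK.
        return x.to(memory_format=torch.channels_last)


class DtypeAndDimOrder(torch.nn.Module):
    def forward(self, x):
        # int32 -> float32 cast: dtypes differ, so check_constraints
        # returns False for this node.
        return x.to(torch.float, memory_format=torch.channels_last)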

Lines changed: 111 additions & 0 deletions (new file)
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().check_not(
+            ["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
+        ).to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelLastBeforeLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            return self.linear(y)
+
+    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
+
+    def test_channel_last_before_linear(self):
+        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
+
+    class ContiguousBeforeConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ContiguousBeforeConvModule = ContiguousBeforeConv()
+
+    def test_contiguous_before_conv(self):
+        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
+
+    class DtypeAndMemoryFormatConversion(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.conv(y)
+
+    DtypeAndMemoryFormatConversionModule = DtypeAndMemoryFormatConversion()
+
+    def test_dtype_and_memory_format_conversion(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatConversionModule,
+            (torch.randint(0, 10, (1, 3, 6, 6), dtype=torch.int32),),
+        )
+
+    class DtypeAndMemoryFormatWithLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.to(torch.float, memory_format=torch.channels_last)
+            return self.linear(y)
+
+    DtypeAndMemoryFormatWithLinearModule = DtypeAndMemoryFormatWithLinear()
+
+    def test_dtype_and_memory_format_with_linear(self):
+        self.run_tester(
+            self.DtypeAndMemoryFormatWithLinearModule,
+            (torch.randint(0, 10, (1, 3, 3, 3), dtype=torch.int16),),
+        )
+
+    class QuantizedToCopy(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.to(memory_format=torch.channels_last)
+            return self.conv(y)
+
+    QuantizedToCopyModule = QuantizedToCopy()
+
+    def test_quantized_to_copy(self):
+        tester = Tester(
+            self.QuantizedToCopyModule.eval(),
+            (torch.randn(1, 3, 6, 6),),
+        )
+
+        tester.quantize().export().to_edge_transform_and_lower().check_not(
+            [
+                "executorch_exir_dialects_edge__ops_aten__to_copy_default",
+                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
+            ]
+        ).to_executorch().serialize().run_method_and_compare_outputs()

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 20 additions & 1 deletion
@@ -48,7 +48,9 @@ def run_tester(self, module, inputs):
             module.eval(),
             inputs,
         )
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+        tester.export().to_edge_transform_and_lower().check_not(
+            ["executorch_exir_dialects_edge__ops_aten__to_copy_default"]
+        ).to_executorch().serialize().run_method_and_compare_outputs()
 
     class LinearConv(torch.nn.Module):
         def __init__(self):
@@ -173,6 +175,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
+    class LinearConvDimSwap(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    LinearConvDimSwapModule = LinearConvDimSwap()
+
+    def test_conv_linear_dim_order_swap_partitioner(self):
+        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
