
Commit 62d52fc

digantdesai authored and facebook-github-bot committed
XNNPACK: Add support for clone
Summary:
* Partition `_clone_dim_order.default`
* Revert back to `aten.clone.default`
* Run `RemoveCloneOpsTransform` to remove redundant clones
* Lower `aten.clone.default` to an XNNPACK static transpose if any remain
* Add tests

Differential Revision: D83560001
1 parent f7c009e commit 62d52fc

File tree

7 files changed: +189 −3 lines changed

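The change works end to end as follows: a `clone` (exported as `_clone_dim_order.default`, then reverted to `aten.clone.default`) is either folded away by `RemoveCloneOpsTransform` or lowered as a static transpose inside the XNNPACK delegate. A minimal, hedged usage sketch of that flow (the module and input shapes are illustrative, not part of this commit):

    # Illustrative only: lower a model containing a memory-format clone.
    import torch
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class CloneThenConv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            # The clone changes only physical layout, never values or dtype.
            return self.conv(x.clone(memory_format=torch.channels_last))

    ep = torch.export.export(CloneThenConv().eval(), (torch.randn(1, 3, 6, 6),))
    # After this commit, no aten.clone.default edge op should survive lowering.
    lowered = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])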

backends/xnnpack/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ runtime.python_library(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/transforms:addmm_mm_to_linear",
+        "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/backends/transforms:lib",
         "//executorch/backends/xnnpack/partition:partitioner_graphs",
         "//executorch/backends/xnnpack/serialization:xnnpack_schema",

backends/xnnpack/_passes/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -4,8 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import List, Optional, Type
 
+from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
+
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 
 from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
@@ -38,6 +41,9 @@
 
 from torch.export import ExportedProgram
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
 
 class XNNPACKPassManager:
     def __init__(
@@ -69,6 +75,7 @@ def __init__(
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
                 RemoveRedundantCopyPass,
+                RemoveCloneOpsTransform,
             ]
         else:
             self.passes = passes
@@ -92,4 +99,6 @@ def transform(self) -> ExportedProgram:
                     f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
                 )
             ep = _transform(ep, transform_pass)
+            logger.debug(f"Running {pass_.__name__} pass")
+            logger.debug(f"Transformed program: {ep}")
         return ep
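For intuition, `RemoveCloneOpsTransform` (imported above from `executorch/backends/transforms/remove_clone_ops`) strips clones that change nothing observable. A simplified FX sketch of that idea, not the actual implementation:

    # Simplified sketch only; the real pass handles more cases.
    import torch
    from torch.fx import GraphModule
    from executorch.exir.dialects._ops import ops as exir_ops

    def remove_redundant_clones(gm: GraphModule) -> GraphModule:
        for node in list(gm.graph.nodes):
            is_clone = (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.clone.default
            )
            # A clone that preserves memory format is a pure copy: rewire its
            # users to its input and drop the node.
            if is_clone and node.kwargs.get("memory_format") in (None, torch.preserve_format):
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.recompile()
        return gm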

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 4 additions & 1 deletion
@@ -493,7 +493,10 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
-            elif node.target == exir_ops.edge.aten._to_copy.default:
+            elif node.target in [
+                exir_ops.edge.aten._to_copy.default,
+                exir_ops.edge.aten.clone.default,
+            ]:
                 self.tag_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
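Tagging `clone` alongside `_to_copy` is necessary because a clone with an explicit `memory_format` changes the tensor's physical layout even though its values are identical, so the pass must treat it as an NCHW/NHWC boundary. In eager PyTorch:

    import torch

    x = torch.randn(1, 3, 6, 6)                     # contiguous NCHW
    y = x.clone(memory_format=torch.channels_last)  # same values, NHWC storage

    assert torch.equal(x, y)                                   # logically identical
    assert y.is_contiguous(memory_format=torch.channels_last)  # relaid out in memory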

backends/xnnpack/operators/op_to_copy.py

Lines changed: 19 additions & 2 deletions
@@ -28,9 +28,7 @@
 )
 
 
-@register_node_visitor
 class ConvertMemoryFormat(NodeVisitor):
-    target = "aten._to_copy.default"
 
     def __init__(self, *args) -> None:
         super().__init__(*args)
@@ -54,6 +52,13 @@ def define_node(
         input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
         output_quant_params = QuantParams.from_outputs(node)
 
+        # Ensure input and output have the same dtype
+        input_dtype = input_node.meta["val"].dtype
+        output_dtype = node.meta["val"].dtype
+        assert (
+            input_dtype == output_dtype
+        ), f"Input dtype {input_dtype} must match output dtype {output_dtype} for {node.target}. Expected dtype to not change."
+
         permute_order = PERM_NCHW_TO_NHWC if to_channels_last else PERM_NHWC_TO_NCHW
 
         self.define_tensor(
@@ -89,3 +94,15 @@ def define_node(
             debug_handle=debug_handle,
         )
         xnn_graph.xnodes.append(ser_node)
+
+
+@register_node_visitor
+class ConvertMemoryFormatToCopy(ConvertMemoryFormat):
+
+    target = "aten._to_copy.default"
+
+
+@register_node_visitor
+class ConvertMemoryFormatClone(ConvertMemoryFormat):
+
+    target = "aten.clone.default"

backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     CatConfig,
     CeilConfig,
     ClampConfig,
+    CloneDimOrderConfig,
     ConstantPadConfig,
     DeQuantizedPerTensorConfig,
     DivConfig,
@@ -117,4 +118,5 @@
     QuantizeAffineConfig,
     DeQuantizeAffineConfig,
     ChooseQParamsAffineConfig,
+    CloneDimOrderConfig,
 ]

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 8 additions & 0 deletions
@@ -454,6 +454,14 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
 
 
+class CloneDimOrderConfig(ToDimOrderCopyConfig):
+    target_name = "_clone_dim_order.default"
+
+    """
+    Similar to ToDimOrderCopyConfig, but with a different target name. The dtype should not change in either case.
+    """
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
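The dtype constraint matters here: `CloneDimOrderConfig` inherits its checks from `ToDimOrderCopyConfig`, and the assert added in `op_to_copy.py` above backstops the same invariant at serialization time. As a rough sketch of the kind of check the config layer performs (the `check_constraints(node, ep)` hook and its use here are assumptions, not code from this commit):

    # Hypothetical sketch of a dtype guard in a partitioner config; the real
    # check is inherited from ToDimOrderCopyConfig.
    def check_constraints(self, node, ep) -> bool:
        # A clone is only a layout change; a dtype cast cannot lower to a
        # static transpose, so refuse to partition it.
        input_dtype = node.all_input_nodes[0].meta["val"].dtype
        output_dtype = node.meta["val"].dtype
        return input_dtype == output_dtype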

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestCloneMemoryFormat(unittest.TestCase):
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().check_not(
+            ["executorch_exir_dialects_edge__ops_aten_clone_default"]
+        ).to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelLastBeforeLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 3)
+
+        def forward(self, x):
+            y = x.clone(memory_format=torch.channels_last)
+            return self.linear(y)
+
+    ChannelLastBeforeLinearModule = ChannelLastBeforeLinear()
+
+    def test_channel_last_before_linear(self):
+        self.run_tester(self.ChannelLastBeforeLinearModule, (torch.randn(1, 3, 3, 3),))
+
+    class ContiguousBeforeConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = x.clone(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    ContiguousBeforeConvModule = ContiguousBeforeConv()
+
+    def test_contiguous_before_conv(self):
+        self.run_tester(self.ContiguousBeforeConvModule, (torch.randn(1, 3, 6, 6),))
+
+    class CloneChannelsLastToContiguous(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            # Start with channels_last input
+            x_channels_last = x.to(memory_format=torch.channels_last)
+            # Clone to contiguous format
+            y = x_channels_last.clone(memory_format=torch.contiguous_format)
+            return self.conv(y)
+
+    CloneChannelsLastToContiguousModule = CloneChannelsLastToContiguous()
+
+    def test_clone_channels_last_to_contiguous(self):
+        self.run_tester(
+            self.CloneChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 6),)
+        )
+
+    class CloneContiguousToChannelsLast(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            # Clone contiguous input to channels_last format
+            y = x.clone(memory_format=torch.channels_last)
+            return self.conv(y)
+
+    CloneContiguousToChannelsLastModule = CloneContiguousToChannelsLast()
+
+    def test_clone_contiguous_to_channels_last(self):
+        self.run_tester(
+            self.CloneContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 6),)
+        )
+
+    class SimpleClone(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            # Simple clone without memory format (should default to contiguous)
+            y = x.clone()
+            return self.conv(y)
+
+    SimpleCloneModule = SimpleClone()
+
+    def test_simple_clone(self):
+        self.run_tester(self.SimpleCloneModule, (torch.randn(1, 3, 6, 6),))
+
+    class QuantizedClone(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            y = self.conv(x)
+            y = y.clone(memory_format=torch.contiguous_format)
+            return self.conv2(y)
+
+    QuantizedCloneModule = QuantizedClone()
+
+    def test_quantized_clone(self):
+        tester = Tester(
+            self.QuantizedCloneModule.eval(),
+            (torch.randn(1, 3, 9, 9),),
+        )
+
+        tester.quantize().export().to_edge_transform_and_lower().check_not(
+            [
+                "executorch_exir_dialects_edge__ops_aten_clone_default",
+                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
+            ]
+        ).to_executorch().serialize().run_method_and_compare_outputs(qtol=1)
+
+    class ChainedClone(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            # Chain multiple clones with different memory formats
+            y = x.clone(memory_format=torch.channels_last)
+            z = y.clone(memory_format=torch.contiguous_format)
+            return self.conv(z)
+
+    ChainedCloneModule = ChainedClone()
+
+    def test_chained_clone(self):
+        self.run_tester(self.ChainedCloneModule, (torch.randn(1, 3, 6, 6),))
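Taken together, every test asserts via `check_not` that no `aten.clone.default` edge op survives `to_edge_transform_and_lower`: each clone must either be removed as redundant or absorbed into the XNNPACK delegate. The quantized variant additionally checks that the quantize op is consumed, and `qtol=1` tolerates an off-by-one difference in quantized units when comparing outputs.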
