
Commit 8b91628

partition to_dim_order_copy in XNN delegate
1 parent 551f6b7 commit 8b91628

5 files changed: +20 −81 lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 5 additions & 0 deletions
@@ -400,6 +400,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # The node requires nchw inputs
                 for input_node in node.all_input_nodes:
                     self.input_to_nchw(graph_module, input_node, node)
+            elif node.target == exir_ops.edge.aten._to_copy.default:
+                if node.meta["val"].is_contiguous():
+                    self.mark_as_nchw_node(node)
+                else:
+                    self.mark_as_nhwc_node(node)
             else:
                 # The node can have inputs in any format (but all must be the
                 # same format)
backends/xnnpack/partition/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@
     SoftmaxConfig,
     SquareRootConfig,
     SubConfig,
+    ToDimOrderCopyConfig,
     UpsampleBilinear2dConfig,
 )
 from executorch.backends.xnnpack.partition.config.node_configs import (
@@ -97,6 +98,7 @@
     PreluConfig,
     ReciprocalSquareRootConfig,
     ReLUConfig,
+    ToDimOrderCopyConfig,
     # SDPAConfig, TODO: D60553559: preserving SDPA for fairseq fails
     SigmoidConfig,
     SliceCopyConfig,

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 7 additions & 0 deletions
@@ -364,6 +364,13 @@ def supported_precision_types(self) -> List[ConfigPrecisionType]:
         return [ConfigPrecisionType.FP32]
 
 
+class ToDimOrderCopyConfig(GenericNodePartitionerConfig):
+    target_name = "_to_dim_order_copy.default"
+
+    def supported_precision_types(self) -> List[ConfigPrecisionType]:
+        return [ConfigPrecisionType.FP32]
+
+
 class MeanDimConfig(GenericNodePartitionerConfig):
     target_name = "mean.dim"
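
With ToDimOrderCopyConfig registered, _to_dim_order_copy.default nodes become candidates for the XNNPACK partitioner (FP32 only, per supported_precision_types). A rough sketch of exercising this through the usual ExecuTorch export-and-lower flow; the DimSwap module is illustrative, and the import paths follow the existing ExecuTorch APIs:

```python
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class DimSwap(torch.nn.Module):
    def forward(self, x):
        # A memory-format round trip that exports as dim-order copy ops.
        y = x.to(memory_format=torch.channels_last)
        return y.to(memory_format=torch.contiguous_format)


ep = export(DimSwap().eval(), (torch.randn(1, 3, 6, 4),))

# After this commit, the dim-order copies can be claimed by the XNNPACK
# delegate rather than falling back to the portable runtime.
edge = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()])
program = edge.to_executorch()
```
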

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 0 additions & 11 deletions
@@ -114,13 +114,8 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
         XNN_MAX_TENSOR_DIMS,
         num_dims);
 
-<<<<<<< HEAD
     for (int j = 0; j < num_dims; ++j) {
       dims[j] = tensor->size(static_cast<int>(dim_order[j]));
-=======
-    for (int i = 0; i < num_dims; ++i) {
-      dims[i] = tensor->size(static_cast<int>(dim_order[i]));
->>>>>>> b5785b2fda (support channels last dim order in xnnpack)
     }
     status =
         xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -229,15 +224,9 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
         Internal,
         "Failed to retrieve dim order from tensor!");
 
-<<<<<<< HEAD
     for (int j = 0; j < num_dim; ++j) {
       expected_output_size[static_cast<int>(dim_order[j])] =
           static_cast<SizesType>(dims[j]);
-=======
-    for (int i = 0; i < num_dim; ++i) {
-      expected_output_size[static_cast<int>(dim_order[i])] =
-          static_cast<SizesType>(dims[i]);
->>>>>>> b5785b2fda (support channels last dim order in xnnpack)
     }
 
     executorch::aten::ArrayRef<SizesType> output_size{

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 6 additions & 70 deletions
@@ -130,87 +130,23 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
             .run_method_and_compare_outputs()
         )
 
-    class LinearConv(torch.nn.Module):
+    class LinearConvDimSwap(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.conv1 = torch.nn.Conv2d(3, 3, 3)
             self.linear1 = torch.nn.Linear(4, 3)
 
         def forward(self, x):
             y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
             return self.conv1(y)
 
-    def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
-        tester = Tester(
-            self.LinearConv().eval(),
-            (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
-        )
-
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
-
-    def test_conv_linear_dim_order_swaps_on_nchw_input(self):
-        tester = Tester(
-            self.LinearConv().eval(),
-            (torch.randn(1, 3, 6, 4),),
-        )
-
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
-
-    class ConvLinearConv(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.conv1 = torch.nn.Conv2d(3, 3, 3)
-            self.linear1 = torch.nn.Linear(4, 4)
-
-        def forward(self, x):
-            y = self.conv1(x)
-            return self.linear1(y)
-
-    def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
-        tester = Tester(
-            self.ConvLinearConv().eval(),
-            (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
-        )
-
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
-
-    def test_linear_conv_dim_order_swaps_on_nchw_input(self):
-        tester = Tester(
-            self.ConvLinearConv().eval(),
-            (torch.randn(1, 3, 6, 6),),
-        )
-
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
-
-    class Bilinear(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
-        def forward(self, x):
-            return torch.nn.functional.interpolate(
-                x, scale_factor=2, mode="bilinear", align_corners=True
-            )
-
-    def test_nhwc_input_on_nhwc_op(self):
-        tester = Tester(
-            self.Bilinear().eval(),
-            (
-                torch.arange(8)
-                .reshape(1, 2, 2, 2)
-                .to(torch.float32)
-                .to(memory_format=torch.channels_last),
-            ),
-        )
+    LinearConvDimSwapModule = LinearConvDimSwap()
 
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+    def test_conv_linear_dim_order_swap_partitioner(self):
+        self.run_tester(self.LinearConvDimSwapModule, (torch.randn(1, 3, 6, 4),))
 
-    def test_nchw_input_on_nhwc_op(self):
-        tester = Tester(
-            self.Bilinear().eval(),
-            (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
-        )
-
-        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
 
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
0 commit comments
