Skip to content

Commit 2905b98

Browse files
committed
Corrects input to conv
1 parent 4d064da commit 2905b98

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,34 @@ def input_to_nhwc(
283283
]
284284
else:
285285
# Need to create NHWC node
286-
origin = input_node
286+
source_node = input_node
287+
287288
# TODO: safe/correct to always trace back?
288-
# Trace back to source node
289-
while hasattr(origin, "args") and isinstance(origin.args, tuple) and len(origin.args) > 0:
290-
origin = origin.args[0]
289+
# Trace back to find original source node
290+
while (
291+
hasattr(source_node, "args")
292+
and isinstance(source_node.args, tuple)
293+
and len(source_node.args) > 0
294+
):
295+
source_node = source_node.args[0]
291296

292-
with graph_module.graph.inserting_after(origin):
297+
with graph_module.graph.inserting_after(source_node):
293298
input_node_nhwc = self.create_call_function_node(
294299
graph_module=graph_module,
295300
target=exir_ops.edge.aten._to_copy.default,
296-
args=(origin,),
301+
args=(source_node,),
297302
memory_format=torch.channels_last,
298303
)
299304

300-
# If input_node was not source
301-
if origin != input_node:
302-
print("Permuted\n\n")
305+
# If input_node was not the original source node
306+
if source_node != input_node:
307+
input_node = source_node
303308
# Replace downstream source node with NHWC node
304-
for user in list(origin.users):
309+
for user in list(input_node.users):
305310
if user != input_node_nhwc:
306-
user.replace_input_with(origin, input_node_nhwc)
311+
user.replace_input_with(input_node, input_node_nhwc)
307312
graph_module.recompile()
308313

309-
self.mark_as_nhwc_node(input_node_nhwc)
310-
311314
self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
312315
graph_module=graph_module,
313316
original_input=input_node,

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,14 +249,10 @@ def _test_dq_conv2d(
249249

250250
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
251251
tester = tester.quantize(Quantize(quantization_config=quant_config))
252-
exported = tester.export()
252+
tester.export()
253253

254254
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
255255

256-
tester.stages["export"] = exported.stages[exported.cur]
257-
print("\n----------Exported Graph:")
258-
print(tester.stages["export"].graph_module.code)
259-
260256
tester.to_edge_transform_and_lower(
261257
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
262258
)
@@ -265,7 +261,7 @@ def _test_dq_conv2d(
265261
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
266262

267263
tester.to_executorch()
268-
#tester.serialize()
264+
# tester.serialize()
269265
tester.serialize().dump_artifact("conv2d.pte")
270266

271267
tester.run_method_and_compare_outputs(atol=atol)
@@ -751,7 +747,11 @@ def test_dq_conv2d(self) -> None:
751747
class SimpleConv2d(torch.nn.Module):
752748
def __init__(self):
753749
super().__init__()
754-
self.conv = torch.nn.Conv2d(3, 10, 3, )
750+
self.conv = torch.nn.Conv2d(
751+
3,
752+
10,
753+
3,
754+
)
755755
self.conv.weight.requires_grad = False
756756
self.conv.bias.requires_grad = False
757757

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def preprocess(
144144
graph_module = ep.graph_module
145145

146146
node_to_external_map = generate_node_to_external_map(ep, graph_module)
147-
print("\n----------XNNPack Preprocess Graph:", graph_module)
147+
148148
# Make sure all inputs are contiguous_format or NCHW or default dim order
149149
assert_default_dim_order(graph_module)
150150

0 commit comments

Comments
 (0)