Skip to content

Commit 4d064da

Browse files
committed
Refactor permute code
1 parent 8fcb117 commit 4d064da

File tree

3 files changed

+16
-24
lines changed

3 files changed

+16
-24
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,13 @@ def input_to_nhwc(
282282
ChannelsLastTaggedReshapePass.PARTNER_NODE
283283
]
284284
else:
285-
# trace back to permute
285+
# Need to create NHWC node
286286
origin = input_node
287+
# TODO: safe/correct to always trace back?
288+
# Trace back to source node
287289
while hasattr(origin, "args") and isinstance(origin.args, tuple) and len(origin.args) > 0:
288290
origin = origin.args[0]
289291

290-
# at x choose_qparams and quantize insert permute
291292
with graph_module.graph.inserting_after(origin):
292293
input_node_nhwc = self.create_call_function_node(
293294
graph_module=graph_module,
@@ -296,24 +297,17 @@ def input_to_nhwc(
296297
memory_format=torch.channels_last,
297298
)
298299

299-
for user in list(origin.users):
300-
if user != input_node_nhwc:
301-
user.replace_input_with(origin, input_node_nhwc)
300+
# If input_node was not source
301+
if origin != input_node:
302+
print("Permuted\n\n")
303+
# Replace downstream source node with NHWC node
304+
for user in list(origin.users):
305+
if user != input_node_nhwc:
306+
user.replace_input_with(origin, input_node_nhwc)
307+
graph_module.recompile()
302308

303-
graph_module.recompile()
304309
self.mark_as_nhwc_node(input_node_nhwc)
305310

306-
# TODO: uncomment, use case when permute not needed
307-
# # Need to create NHWC node ----------------------------- CONVERSION HAPPENING ----->>
308-
# with graph_module.graph.inserting_after(input_node):
309-
# input_node_nhwc = self.create_call_function_node(
310-
# graph_module=graph_module,
311-
# target=exir_ops.edge.aten._to_copy.default,
312-
# args=(input_node,),
313-
# memory_format=torch.channels_last,
314-
# )
315-
# self.mark_as_nhwc_node(input_node_nhwc)
316-
317311
self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
318312
graph_module=graph_module,
319313
original_input=input_node,

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,14 @@ def _test_dq_conv2d(
249249

250250
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
251251
tester = tester.quantize(Quantize(quantization_config=quant_config))
252-
253-
tester.stages["quantize"] = tester.stages[tester.cur]
254-
255252
exported = tester.export()
256253

257-
tester.stages["export"] = exported.stages[exported.cur]
258-
259254
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
260255

256+
tester.stages["export"] = exported.stages[exported.cur]
257+
print("\n----------Exported Graph:")
258+
print(tester.stages["export"].graph_module.code)
259+
261260
tester.to_edge_transform_and_lower(
262261
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
263262
)
@@ -266,7 +265,6 @@ def _test_dq_conv2d(
266265
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
267266

268267
tester.to_executorch()
269-
270268
#tester.serialize()
271269
tester.serialize().dump_artifact("conv2d.pte")
272270

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def preprocess(
144144
graph_module = ep.graph_module
145145

146146
node_to_external_map = generate_node_to_external_map(ep, graph_module)
147-
147+
print("\n----------XNNPack Preprocess Graph:", graph_module)
148148
# Make sure all inputs are contiguous_format or NCHW or default dim order
149149
assert_default_dim_order(graph_module)
150150

0 commit comments

Comments
 (0)