Skip to content

Commit 8fcb117

Browse files
committed
Permute before quant
1 parent 0b5b0e8 commit 8fcb117

File tree

5 files changed

+34
-33
lines changed

5 files changed

+34
-33
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,16 +282,38 @@ def input_to_nhwc(
282282
ChannelsLastTaggedReshapePass.PARTNER_NODE
283283
]
284284
else:
285-
# Need to create NHWC node
286-
with graph_module.graph.inserting_after(input_node):
285+
# Trace back through single-input ancestors to the original input node
286+
origin = input_node
287+
while hasattr(origin, "args") and isinstance(origin.args, tuple) and len(origin.args) > 0:
288+
origin = origin.args[0]
289+
290+
# Insert the channels-last permute before x, choose_qparams, and quantize
291+
with graph_module.graph.inserting_after(origin):
287292
input_node_nhwc = self.create_call_function_node(
288293
graph_module=graph_module,
289294
target=exir_ops.edge.aten._to_copy.default,
290-
args=(input_node,),
295+
args=(origin,),
291296
memory_format=torch.channels_last,
292297
)
298+
299+
for user in list(origin.users):
300+
if user != input_node_nhwc:
301+
user.replace_input_with(origin, input_node_nhwc)
302+
303+
graph_module.recompile()
293304
self.mark_as_nhwc_node(input_node_nhwc)
294305

306+
# TODO: re-enable this path for the case where no permute is needed
307+
# # Need to create NHWC node ----------------------------- CONVERSION HAPPENING ----->>
308+
# with graph_module.graph.inserting_after(input_node):
309+
# input_node_nhwc = self.create_call_function_node(
310+
# graph_module=graph_module,
311+
# target=exir_ops.edge.aten._to_copy.default,
312+
# args=(input_node,),
313+
# memory_format=torch.channels_last,
314+
# )
315+
# self.mark_as_nhwc_node(input_node_nhwc)
316+
295317
self.insert_copy_and_assign_partner_nodes_quantization_sensitive(
296318
graph_module=graph_module,
297319
original_input=input_node,

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def _supported_symmetric_quantized_operators() -> dict[str, list[OperatorPattern
7171
"conv2d": [
7272
[torch.nn.Conv2d, torch.nn.ReLU],
7373
[torch.nn.Conv2d, F.relu],
74+
[torch.nn.Conv2d],
7475
[F.conv2d, torch.nn.ReLU],
7576
[F.conv2d, F.relu],
77+
[F.conv2d],
7678
],
7779
"linear": [[torch.nn.Linear], [F.linear]],
7880
"add": [[torch.add]],

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -305,16 +305,6 @@ def _do_annotate_conv(
305305
if not is_conv_node(n):
306306
continue
307307

308-
# TODO: Check for dynamically quantized convs and check if nn.Conv2d is always lowered
309-
# Only dynamically quantize 2D convolutions
310-
# Handle both nn.Conv2d and aten.conv2d.default
311-
if n.op == "call_module":
312-
mod = gm.get_submodule(n.target)
313-
if not hasattr(mod, "padding") or len(mod.padding) != 2:
314-
continue
315-
elif n.op == "call_function" and n.target != torch.ops.aten.conv2d.default:
316-
continue
317-
318308
conv_node = n
319309

320310
# This is hacky!

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ Error defineStaticTransposeNode(
11721172
ET_CHECK_OR_RETURN_ERROR(
11731173
status == xnn_status_success,
11741174
Internal,
1175-
"Failed to create sigmoid node %i with code: %s",
1175+
"Failed to create static transpose node %i with code: %s",
11761176
node->debug_handle(),
11771177
xnn_status_to_string(status));
11781178

backends/xnnpack/test/ops/test_conv2d.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,6 @@ def _test_dq_conv2d(
240240
quant_config = get_symmetric_quantization_config(
241241
is_per_channel=True,
242242
is_dynamic=True,
243-
act_qmin=-128,
244-
act_qmax=127,
245-
weight_qmin=-128,
246-
weight_qmax=127,
247243
)
248244

249245
DynamicallyQuantizedPartitioner = XnnpackPartitioner(
@@ -254,35 +250,26 @@ def _test_dq_conv2d(
254250
tester = Tester(m, inputs, dynamic_shapes=dynamic_shapes)
255251
tester = tester.quantize(Quantize(quantization_config=quant_config))
256252

257-
# Print after quantization
258253
tester.stages["quantize"] = tester.stages[tester.cur]
259-
print("\n----------Annotated Graph:")
260-
print(tester.stages["quantize"].graph_module.code)
261254

262255
exported = tester.export()
263256

264-
# Print after exporting
265257
tester.stages["export"] = exported.stages[exported.cur]
266-
print("\n----------Exported Graph:")
267-
print(tester.stages["export"].graph_module.code)
268258

269-
# Check for choose_qparams
270259
tester.check(["torch.ops.quantized_decomposed.choose_qparams"])
271260

272261
tester.to_edge_transform_and_lower(
273262
ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
274263
)
275264

276-
# Print after lower and partition
277-
print("\n----------Lowered Graph:")
278-
print(tester.stages[tester.cur].graph_module.code)
279-
280-
tester.check(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
281265
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
282266
tester.check_not(["executorch_exir_dialects_edge__ops_aten_conv2d_default"])
283267

284268
tester.to_executorch()
285-
tester.serialize()
269+
270+
#tester.serialize()
271+
tester.serialize().dump_artifact("conv2d.pte")
272+
286273
tester.run_method_and_compare_outputs(atol=atol)
287274

288275
def test_fp16_conv2d(self) -> None:
@@ -766,15 +753,15 @@ def test_dq_conv2d(self) -> None:
766753
class SimpleConv2d(torch.nn.Module):
767754
def __init__(self):
768755
super().__init__()
769-
self.conv = torch.nn.Conv2d(1, 2, 3)
756+
self.conv = torch.nn.Conv2d(3, 10, 3, )
770757
self.conv.weight.requires_grad = False
771758
self.conv.bias.requires_grad = False
772759

773760
def forward(self, x):
774761
return self.conv(x)
775762

776763
def get_inputs(self):
777-
return (torch.randn(1, 1, 8, 8),)
764+
return (torch.randn(1, 3, 8, 8),)
778765

779766
model = SimpleConv2d()
780767
self._test_dq_conv2d(

0 commit comments

Comments
 (0)