Skip to content

Commit 2087aa4

Browse files
skywallrobert-kalmar
authored andcommitted
NXP backend: Quantize input placeholders in NeutronQuantizer
1 parent 3eea912 commit 2087aa4

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
no_outside_users,
4242
)
4343
from torch import fx
44+
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
4445
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
4546
from torchao.quantization.pt2e.quantizer import (
4647
ComposableQuantizer,
@@ -236,6 +237,8 @@ def transform_for_annotation(
236237
return pass_runner(model).graph_module
237238

238239
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
240+
self._annotate_inputs(model)
241+
239242
nodes = list(model.graph.nodes)
240243
for node in nodes:
241244
if (
@@ -251,5 +254,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
251254

252255
return model
253256

257+
def _is_input_annotated(self, node: fx.Node) -> bool:
258+
return (
259+
"quantization_annotation" in node.meta
260+
and node.meta["quantization_annotation"]._annotated
261+
)
262+
263+
def _mark_input_node_as_annotated(self, node: fx.Node) -> None:
264+
if "quantization_annotation" not in node.meta:
265+
node.meta["quantization_annotation"] = QuantizationAnnotation()
266+
node.meta["quantization_annotation"]._annotated = True
267+
268+
def _annotate_inputs(self, model: fx.GraphModule):
269+
for node in model.graph.nodes:
270+
if self._is_input_annotated(node):
271+
continue
272+
273+
if node.op == "placeholder" and len(node.users) > 0:
274+
_annotate_output_qspec(node, act_qspec)
275+
self._mark_input_node_as_annotated(node)
276+
254277
def validate(self, model: torch.fx.GraphModule) -> None:
255278
return super().validate(model)

backends/nxp/tests/test_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def test_quantizer_single_maxpool2d():
195195
m(*example_input)
196196

197197
nodes = list(m.graph.nodes)
198-
assert len(nodes) == 3
199-
assert nodes[1].name == "max_pool2d"
198+
assert len(nodes) == 7
199+
assert nodes[3].name == "max_pool2d"
200200
assert "quantization_annotation" not in nodes[1].meta
201201

202202

0 commit comments

Comments
 (0)