Skip to content

Commit 3117b63

Browse files
skywall authored and StrycekSimon committed
NXP backend: Quantize input placeholders in NeutronQuantizer
1 parent d6e25e2 commit 3117b63

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
no_outside_users,
3636
)
3737
from torch import fx
38+
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
3839
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
3940
from torchao.quantization.pt2e.quantizer import (
4041
ComposableQuantizer,
@@ -224,6 +225,8 @@ def transform_for_annotation(
224225
return pass_runner(model).graph_module
225226

226227
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
228+
self._annotate_inputs(model)
229+
227230
nodes = list(model.graph.nodes)
228231
for node in nodes:
229232
if (
@@ -239,5 +242,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
239242

240243
return model
241244

245+
def _is_input_annotated(self, node: fx.Node) -> bool:
    """Return whether ``node`` is already flagged as annotated.

    A node counts as annotated when ``node.meta`` holds a
    ``"quantization_annotation"`` entry whose ``_annotated`` flag is truthy.

    Args:
        node: Graph node whose ``meta`` mapping is inspected.

    Returns:
        ``True`` if the annotation entry exists and is flagged, else ``False``.
    """
    # Single lookup instead of a membership test followed by indexing.
    annotation = node.meta.get("quantization_annotation")
    return annotation is not None and bool(annotation._annotated)
250+
251+
def _mark_input_node_as_annotated(self, node: fx.Node) -> None:
    """Flag ``node`` as annotated, creating its annotation entry if absent.

    If ``node.meta`` has no ``"quantization_annotation"`` entry, a fresh
    ``QuantizationAnnotation()`` is stored first; an existing entry is kept
    and only its ``_annotated`` flag is set to ``True``.

    Args:
        node: Graph node whose ``meta`` mapping is updated in place.
    """
    if "quantization_annotation" not in node.meta:
        node.meta["quantization_annotation"] = QuantizationAnnotation()
    node.meta["quantization_annotation"]._annotated = True
255+
256+
def _annotate_inputs(self, model: fx.GraphModule) -> None:
    """Annotate every used, not-yet-annotated placeholder node of ``model``.

    For each node in ``model.graph.nodes`` whose ``op`` is ``"placeholder"``,
    that has at least one user, and that is not already flagged by
    ``_is_input_annotated``, the node is passed to ``_annotate_output_qspec``
    together with ``act_qspec`` (defined outside this block) and then marked
    via ``_mark_input_node_as_annotated``.

    Args:
        model: Graph module whose placeholder nodes are processed in place.
    """
    for node in model.graph.nodes:
        # Cheap structural filter first: only graph inputs that are
        # actually consumed are candidates.
        if node.op != "placeholder" or not node.users:
            continue

        # Leave nodes that already carry an annotation untouched.
        if self._is_input_annotated(node):
            continue

        _annotate_output_qspec(node, act_qspec)
        self._mark_input_node_as_annotated(node)
264+
242265
def validate(self, model: torch.fx.GraphModule) -> None:
243266
return super().validate(model)

backends/nxp/tests/test_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def test_quantizer_single_maxpool2d():
195195
m(*example_input)
196196

197197
nodes = list(m.graph.nodes)
198-
assert len(nodes) == 3
199-
assert nodes[1].name == "max_pool2d"
198+
assert len(nodes) == 7
199+
assert nodes[3].name == "max_pool2d"
200200
assert "quantization_annotation" not in nodes[1].meta
201201

202202

0 commit comments

Comments
 (0)