Skip to content

Commit 320d280

Browse files
skywall authored and robert-kalmar committed
NXP backend: Quantize input placeholders in NeutronQuantizer
1 parent 4197fc1 commit 320d280

File tree

2 files changed: 25 additions and 2 deletions.

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
no_outside_users,
4242
)
4343
from torch import fx
44+
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
4445
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
4546
from torchao.quantization.pt2e.quantizer import (
4647
ComposableQuantizer,
@@ -237,6 +238,8 @@ def transform_for_annotation(
237238
return pass_runner(model).graph_module
238239

239240
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
241+
self._annotate_inputs(model)
242+
240243
nodes = list(model.graph.nodes)
241244
for node in nodes:
242245
if (
@@ -252,5 +255,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
252255

253256
return model
254257

258+
def _is_input_annotated(self, node: fx.Node) -> bool:
259+
return (
260+
"quantization_annotation" in node.meta
261+
and node.meta["quantization_annotation"]._annotated
262+
)
263+
264+
def _mark_input_node_as_annotated(self, node: fx.Node) -> None:
265+
if "quantization_annotation" not in node.meta:
266+
node.meta["quantization_annotation"] = QuantizationAnnotation()
267+
node.meta["quantization_annotation"]._annotated = True
268+
269+
def _annotate_inputs(self, model: fx.GraphModule):
270+
for node in model.graph.nodes:
271+
if self._is_input_annotated(node):
272+
continue
273+
274+
if node.op == "placeholder" and len(node.users) > 0:
275+
_annotate_output_qspec(node, act_qspec)
276+
self._mark_input_node_as_annotated(node)
277+
255278
def validate(self, model: torch.fx.GraphModule) -> None:
256279
return super().validate(model)

backends/nxp/tests/test_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def test_quantizer_single_maxpool2d():
195195
m(*example_input)
196196

197197
nodes = list(m.graph.nodes)
198-
assert len(nodes) == 3
199-
assert nodes[1].name == "max_pool2d"
198+
assert len(nodes) == 7
199+
assert nodes[3].name == "max_pool2d"
200200
assert "quantization_annotation" not in nodes[1].meta
201201

202202

0 commit comments

Comments
 (0)