Skip to content

Commit 338c637

Browse files
roman-janik-nxpStrycekSimon
authored andcommitted
NXP backend: Improve quantization annotation process
- fixes multiple subsequent nodes with SharedSpecPattern problem
1 parent 75e4044 commit 338c637

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
1111
NeutronAtenPassManager,
1212
)
13-
1413
from executorch.backends.nxp.quantizer.patterns import (
1514
AddmmPattern,
1615
AvgPoolPattern,
@@ -24,6 +23,7 @@
2423
ReluInPlacePattern,
2524
ReluPattern,
2625
ReshapePattern,
26+
SharedSpecPattern,
2727
SoftMaxPattern,
2828
)
2929
from executorch.backends.nxp.quantizer.utils import (
@@ -202,9 +202,34 @@ def __init__(self):
202202
NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
203203
]
204204
)
205+
self.op_to_quantizer = {
206+
pt: q for q in self.quantizers for pt in q.pattern.partition_types()
207+
}
208+
self.op_to_applied_quantizer = {
209+
pt: False for q in self.quantizers for pt in q.pattern.partition_types()
210+
}
205211

206212
def transform_for_annotation(
207213
self, model: torch.fx.GraphModule
208214
) -> torch.fx.GraphModule:
209215
pass_runner = NeutronAtenPassManager()
210216
return pass_runner(model).graph_module
217+
218+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
219+
nodes = list(model.graph.nodes)
220+
for node in nodes:
221+
if (
222+
node.target not in self.op_to_quantizer
223+
or self.op_to_applied_quantizer[node.target]
224+
):
225+
continue
226+
else:
227+
quantizer = self.op_to_quantizer[node.target]
228+
quantizer.annotate(model)
229+
if not isinstance(quantizer.pattern, SharedSpecPattern):
230+
self.op_to_applied_quantizer[node.target] = True
231+
232+
return model
233+
234+
def validate(self, model: torch.fx.GraphModule) -> None:
235+
return super().validate(model)

0 commit comments

Comments
 (0)