Skip to content

Commit 6a209ff

Browse files
NXP backend: Improve quantization annotation process
- fixes the problem of handling multiple subsequent nodes with SharedSpecPattern
1 parent 00491fd commit 6a209ff

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
1111
NeutronAtenPassManager,
1212
)
13-
1413
from executorch.backends.nxp.quantizer.patterns import (
1514
AddmmPattern,
1615
AvgPoolPattern,
@@ -24,6 +23,7 @@
2423
ReluInPlacePattern,
2524
ReluPattern,
2625
ReshapePattern,
26+
SharedSpecPattern,
2727
SoftMaxPattern,
2828
ViewPattern,
2929
)
@@ -204,9 +204,36 @@ def __init__(self):
204204
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
205205
]
206206
)
207+
# Mapping ops defined in quantizer partition types to its quantizer
208+
self.op_to_quantizer = {
209+
pt: q for q in self.quantizers for pt in q.pattern.partition_types()
210+
}
211+
# Mapping ops to the quantizer application state
212+
self.op_to_applied_quantizer = {
213+
pt: False for q in self.quantizers for pt in q.pattern.partition_types()
214+
}
207215

208216
def transform_for_annotation(
209217
self, model: torch.fx.GraphModule
210218
) -> torch.fx.GraphModule:
211219
pass_runner = NeutronAtenPassManager()
212220
return pass_runner(model).graph_module
221+
222+
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
223+
nodes = list(model.graph.nodes)
224+
for node in nodes:
225+
if (
226+
node.target not in self.op_to_quantizer
227+
or self.op_to_applied_quantizer[node.target]
228+
):
229+
continue
230+
else:
231+
quantizer = self.op_to_quantizer[node.target]
232+
quantizer.annotate(model)
233+
if not isinstance(quantizer.pattern, SharedSpecPattern):
234+
self.op_to_applied_quantizer[node.target] = True
235+
236+
return model
237+
238+
def validate(self, model: torch.fx.GraphModule) -> None:
239+
return super().validate(model)

0 commit comments

Comments (0)