|
10 | 10 | from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( |
11 | 11 | NeutronAtenPassManager, |
12 | 12 | ) |
13 | | - |
14 | 13 | from executorch.backends.nxp.quantizer.patterns import ( |
15 | 14 | AddmmPattern, |
16 | 15 | AvgPoolPattern, |
|
24 | 23 | ReluInPlacePattern, |
25 | 24 | ReluPattern, |
26 | 25 | ReshapePattern, |
| 26 | + SharedSpecPattern, |
27 | 27 | SoftMaxPattern, |
28 | 28 | ) |
29 | 29 | from executorch.backends.nxp.quantizer.utils import ( |
@@ -202,9 +202,34 @@ def __init__(self): |
202 | 202 | NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig), |
203 | 203 | ] |
204 | 204 | ) |
| 205 | + self.op_to_quantizer = { |
| 206 | + pt: q for q in self.quantizers for pt in q.pattern.partition_types() |
| 207 | + } |
| 208 | + self.op_to_applied_quantizer = { |
| 209 | + pt: False for q in self.quantizers for pt in q.pattern.partition_types() |
| 210 | + } |
205 | 211 |
|
206 | 212 | def transform_for_annotation( |
207 | 213 | self, model: torch.fx.GraphModule |
208 | 214 | ) -> torch.fx.GraphModule: |
209 | 215 | pass_runner = NeutronAtenPassManager() |
210 | 216 | return pass_runner(model).graph_module |
| 217 | + |
| 218 | + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 219 | + nodes = list(model.graph.nodes) |
| 220 | + for node in nodes: |
| 221 | + if ( |
| 222 | + node.target not in self.op_to_quantizer |
| 223 | + or self.op_to_applied_quantizer[node.target] |
| 224 | + ): |
| 225 | + continue |
| 226 | + else: |
| 227 | + quantizer = self.op_to_quantizer[node.target] |
| 228 | + quantizer.annotate(model) |
| 229 | + if not isinstance(quantizer.pattern, SharedSpecPattern): |
| 230 | + self.op_to_applied_quantizer[node.target] = True |
| 231 | + |
| 232 | + return model |
| 233 | + |
| 234 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 235 | + return super().validate(model) |
0 commit comments