Skip to content

Commit e7c43b6

Browse files
roman-janik-nxpStrycekSimon
authored andcommitted
NXP backend: Abstract PartitionAnchors annotations of arg indexes
1 parent 7e228ee commit e7c43b6

File tree

2 files changed

+86
-93
lines changed

2 files changed

+86
-93
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import List, Optional, Tuple, Union
8-
97
import torch
108

119
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
@@ -27,6 +25,7 @@
2725
LinearPattern,
2826
MaxPoolPattern,
2927
MeanDimPattern,
28+
NodeArgsIdx,
3029
PadPattern,
3130
PermutePattern,
3231
QuantizationPattern,
@@ -106,57 +105,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
106105
)
107106

108107
def annotate_inputs(
109-
inputs: Union[
110-
List[Tuple[fx.Node, int]],
111-
List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
112-
],
113-
spec: Optional[QuantizationSpec],
108+
inputs: (
109+
list[tuple[fx.Node, NodeArgsIdx]]
110+
| list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
111+
),
112+
spec: QuantizationSpec | None,
114113
) -> None:
115-
for node, idx, *custom_spec in inputs:
114+
for node, args_idx, *custom_spec in inputs:
116115
# pyre-ignore[16]: no attribute
117116
annotation = node.meta.get(
118117
Q_ANNOTATION_KEY,
119118
QuantizationAnnotation(_annotated=True),
120119
)
121120
arg = (
122121
# pyre-ignore[16]: no attribute
123-
node.args[idx]
124-
if isinstance(idx, int)
122+
node.args[args_idx.idx]
123+
if args_idx.inner_idx is None
125124
# pyre-ignore[16]: no attribute
126-
else node.args[idx[0]][idx[1]]
125+
else node.args[args_idx.idx][args_idx.inner_idx]
127126
)
128127
annotation.input_qspec_map[arg] = (
129128
custom_spec[0] if custom_spec else spec
130129
)
131130
# pyre-ignore[16]: no attribute
132131
node.meta[Q_ANNOTATION_KEY] = annotation
133132

134-
def annotate_weights_or_biases(
135-
weights_or_biases: List[Tuple[fx.Node, int]],
136-
spec: Optional[QuantizationSpec],
137-
) -> None:
138-
for node, idx, *custom_spec in weights_or_biases:
139-
annotation = node.meta.get(
140-
Q_ANNOTATION_KEY,
141-
QuantizationAnnotation(_annotated=True),
142-
)
143-
annotation.input_qspec_map[node.args[idx]] = (
144-
custom_spec[0] if custom_spec else spec
145-
)
146-
node.meta[Q_ANNOTATION_KEY] = annotation
147-
148133
# pyre-ignore[6]: incompatible parameter type
149134
annotate_inputs(anchors.inputs, input_act_qspec)
150-
annotate_weights_or_biases(anchors.weights, weight_qspec)
135+
annotate_inputs(anchors.weights, weight_qspec)
151136
# pyre-ignore[6]: incompatible parameter type
152-
annotate_weights_or_biases(anchors.biases, bias_qspec)
137+
annotate_inputs(anchors.biases, bias_qspec)
153138
return model
154139

155140
def validate(self, model: fx.GraphModule) -> None:
156141
pass
157142

158143
@classmethod
159-
def get_supported_operators(cls) -> List[OperatorConfig]:
144+
def get_supported_operators(cls) -> list[OperatorConfig]:
160145
return []
161146

162147

@@ -195,12 +180,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
195180

196181
class NeutronQuantizer(ComposableQuantizer):
197182
def __init__(self):
198-
static_qconfig = QuantizationConfig(
199-
act_qspec,
200-
act_qspec,
201-
wgt_qspec,
202-
None,
203-
)
183+
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
204184
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
205185
super().__init__(
206186
[

0 commit comments

Comments
 (0)