Skip to content

Commit bcfe3ee

Browse files
roman-janik-nxpStrycekSimon
authored andcommitted
NXP backend: Abstract PartitionAnchors annotations of arg indexes
1 parent 5aa127e commit bcfe3ee

File tree

2 files changed

+84
-93
lines changed

2 files changed

+84
-93
lines changed

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import List, Optional, Tuple, Union
8-
97
import torch
108
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
119
NeutronAtenPassManager,
@@ -25,6 +23,7 @@
2523
LinearPattern,
2624
MaxPoolPattern,
2725
MeanDimPattern,
26+
NodeArgsIdx,
2827
PadPattern,
2928
PermutePattern,
3029
QuantizationPattern,
@@ -102,57 +101,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
102101
)
103102

104103
def annotate_inputs(
105-
inputs: Union[
106-
List[Tuple[fx.Node, int]],
107-
List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
108-
],
109-
spec: Optional[QuantizationSpec],
104+
inputs: (
105+
list[tuple[fx.Node, NodeArgsIdx]]
106+
| list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
107+
),
108+
spec: QuantizationSpec | None,
110109
) -> None:
111-
for node, idx, *custom_spec in inputs:
110+
for node, args_idx, *custom_spec in inputs:
112111
# pyre-ignore[16]: no attribute
113112
annotation = node.meta.get(
114113
Q_ANNOTATION_KEY,
115114
QuantizationAnnotation(_annotated=True),
116115
)
117116
arg = (
118117
# pyre-ignore[16]: no attribute
119-
node.args[idx]
120-
if isinstance(idx, int)
118+
node.args[args_idx.idx]
119+
if args_idx.inner_idx is None
121120
# pyre-ignore[16]: no attribute
122-
else node.args[idx[0]][idx[1]]
121+
else node.args[args_idx.idx][args_idx.inner_idx]
123122
)
124123
annotation.input_qspec_map[arg] = (
125124
custom_spec[0] if custom_spec else spec
126125
)
127126
# pyre-ignore[16]: no attribute
128127
node.meta[Q_ANNOTATION_KEY] = annotation
129128

130-
def annotate_weights_or_biases(
131-
weights_or_biases: List[Tuple[fx.Node, int]],
132-
spec: Optional[QuantizationSpec],
133-
) -> None:
134-
for node, idx, *custom_spec in weights_or_biases:
135-
annotation = node.meta.get(
136-
Q_ANNOTATION_KEY,
137-
QuantizationAnnotation(_annotated=True),
138-
)
139-
annotation.input_qspec_map[node.args[idx]] = (
140-
custom_spec[0] if custom_spec else spec
141-
)
142-
node.meta[Q_ANNOTATION_KEY] = annotation
143-
144129
# pyre-ignore[6]: incompatible parameter type
145130
annotate_inputs(anchors.inputs, input_act_qspec)
146-
annotate_weights_or_biases(anchors.weights, weight_qspec)
131+
annotate_inputs(anchors.weights, weight_qspec)
147132
# pyre-ignore[6]: incompatible parameter type
148-
annotate_weights_or_biases(anchors.biases, bias_qspec)
133+
annotate_inputs(anchors.biases, bias_qspec)
149134
return model
150135

151136
def validate(self, model: fx.GraphModule) -> None:
152137
pass
153138

154139
@classmethod
155-
def get_supported_operators(cls) -> List[OperatorConfig]:
140+
def get_supported_operators(cls) -> list[OperatorConfig]:
156141
return []
157142

158143

@@ -191,12 +176,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
191176

192177
class NeutronQuantizer(ComposableQuantizer):
193178
def __init__(self):
194-
static_qconfig = QuantizationConfig(
195-
act_qspec,
196-
act_qspec,
197-
wgt_qspec,
198-
None,
199-
)
179+
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
200180
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
201181
super().__init__(
202182
[

0 commit comments

Comments
 (0)