@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Optional, Tuple, Union
-
 import torch

 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
@@ -27,6 +25,7 @@
     LinearPattern,
     MaxPoolPattern,
     MeanDimPattern,
+    NodeArgsIdx,
     PadPattern,
     PermutePattern,
     QuantizationPattern,
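`NodeArgsIdx`, added to this import list, is the wrapper that replaces the bare `int` / `(int, int)` argument indices in the annotation code below. Its definition is outside this excerpt; a minimal sketch of the shape the code below assumes (an index into `node.args`, plus an optional inner index into a list stored at that position) could look like:

```python
from dataclasses import dataclass


@dataclass
class NodeArgsIdx:
    # Position within node.args, e.g. idx=1 for a conv's weight argument.
    idx: int
    # Optional position inside a list/tuple held at node.args[idx], for ops
    # such as `cat` whose first argument is itself a list of tensors.
    inner_idx: int | None = None
```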
@@ -106,57 +105,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
             )

             def annotate_inputs(
-                inputs: Union[
-                    List[Tuple[fx.Node, int]],
-                    List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
-                ],
-                spec: Optional[QuantizationSpec],
+                inputs: (
+                    list[tuple[fx.Node, NodeArgsIdx]]
+                    | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
+                ),
+                spec: QuantizationSpec | None,
             ) -> None:
-                for node, idx, *custom_spec in inputs:
+                for node, args_idx, *custom_spec in inputs:
                     # pyre-ignore[16]: no attribute
                     annotation = node.meta.get(
                         Q_ANNOTATION_KEY,
                         QuantizationAnnotation(_annotated=True),
                     )
                     arg = (
                         # pyre-ignore[16]: no attribute
-                        node.args[idx]
-                        if isinstance(idx, int)
+                        node.args[args_idx.idx]
+                        if args_idx.inner_idx is None
                         # pyre-ignore[16]: no attribute
-                        else node.args[idx[0]][idx[1]]
+                        else node.args[args_idx.idx][args_idx.inner_idx]
                     )
                     annotation.input_qspec_map[arg] = (
                         custom_spec[0] if custom_spec else spec
                     )
                     # pyre-ignore[16]: no attribute
                     node.meta[Q_ANNOTATION_KEY] = annotation

-            def annotate_weights_or_biases(
-                weights_or_biases: List[Tuple[fx.Node, int]],
-                spec: Optional[QuantizationSpec],
-            ) -> None:
-                for node, idx, *custom_spec in weights_or_biases:
-                    annotation = node.meta.get(
-                        Q_ANNOTATION_KEY,
-                        QuantizationAnnotation(_annotated=True),
-                    )
-                    annotation.input_qspec_map[node.args[idx]] = (
-                        custom_spec[0] if custom_spec else spec
-                    )
-                    node.meta[Q_ANNOTATION_KEY] = annotation
-
             # pyre-ignore[6]: incompatible parameter type
             annotate_inputs(anchors.inputs, input_act_qspec)
-            annotate_weights_or_biases(anchors.weights, weight_qspec)
+            annotate_inputs(anchors.weights, weight_qspec)
             # pyre-ignore[6]: incompatible parameter type
-            annotate_weights_or_biases(anchors.biases, bias_qspec)
+            annotate_inputs(anchors.biases, bias_qspec)
         return model

     def validate(self, model: fx.GraphModule) -> None:
         pass

     @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
+    def get_supported_operators(cls) -> list[OperatorConfig]:
         return []

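With the inner index folded into `NodeArgsIdx`, the lookup that previously branched on `isinstance(idx, int)` versus an `(outer, inner)` tuple is expressed through a single type, so the separate `annotate_weights_or_biases` helper is deleted and `annotate_inputs` now annotates inputs, weights, and biases alike. An illustrative sketch of how anchor tuples might be written under the new scheme, assuming the `PartitionAnchors` container whose `inputs` / `weights` / `biases` fields are consumed above (the node names and index values here are hypothetical, not taken from this diff):

```python
# Hypothetical conv pattern: input at args[0], weight at args[1],
# bias at args[2] of the same node.
anchors = PartitionAnchors(
    inputs=[(conv_node, NodeArgsIdx(0))],
    weights=[(conv_node, NodeArgsIdx(1))],
    biases=[(conv_node, NodeArgsIdx(2))],
)

# A `cat` node keeps its tensors in a list at args[0]; inner_idx selects
# one of them, which annotate_inputs resolves as node.args[0][1].
cat_anchor = (cat_node, NodeArgsIdx(0, inner_idx=1))
```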
@@ -195,12 +180,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]: |
 
 class NeutronQuantizer(ComposableQuantizer):
     def __init__(self):
-        static_qconfig = QuantizationConfig(
-            act_qspec,
-            act_qspec,
-            wgt_qspec,
-            None,
-        )
+        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
         static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
         super().__init__(
             [
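None of this changes the quantizer's entry point, which still plugs into the standard PT2E flow. A minimal sketch for orientation, assuming the module path `executorch.backends.nxp.quantizer.neutron_quantizer` and the usual `torch.export` + `prepare_pt2e` / `convert_pt2e` pipeline (the toy model is illustrative):

```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Export to an ATen-level graph, let the quantizer annotate it, run a
# calibration pass, then fold observers into quantize/dequantize ops.
exported = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(exported, NeutronQuantizer())
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared)
```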