4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
7 | | -from typing import List, Optional, Tuple, Union |
8 | | - |
9 | 7 | import torch |
10 | 8 | from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( |
11 | 9 | NeutronAtenPassManager, |
25 | 23 | LinearPattern, |
26 | 24 | MaxPoolPattern, |
27 | 25 | MeanDimPattern, |
| 26 | + NodeArgsIdx, |
28 | 27 | PadPattern, |
29 | 28 | PermutePattern, |
30 | 29 | QuantizationPattern, |
@@ -102,57 +101,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
102 | 101 | ) |
103 | 102 |
104 | 103 | def annotate_inputs( |
105 | | - inputs: Union[ |
106 | | - List[Tuple[fx.Node, int]], |
107 | | - List[Tuple[fx.Node, int, DerivedQuantizationSpec],], |
108 | | - ], |
109 | | - spec: Optional[QuantizationSpec], |
| 104 | + inputs: ( |
| 105 | + list[tuple[fx.Node, NodeArgsIdx]] |
| 106 | + | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]] |
| 107 | + ), |
| 108 | + spec: QuantizationSpec | None, |
110 | 109 | ) -> None: |
111 | | - for node, idx, *custom_spec in inputs: |
| 110 | + for node, args_idx, *custom_spec in inputs: |
112 | 111 | # pyre-ignore[16]: no attribute |
113 | 112 | annotation = node.meta.get( |
114 | 113 | Q_ANNOTATION_KEY, |
115 | 114 | QuantizationAnnotation(_annotated=True), |
116 | 115 | ) |
117 | 116 | arg = ( |
118 | 117 | # pyre-ignore[16]: no attribute |
119 | | - node.args[idx] |
120 | | - if isinstance(idx, int) |
| 118 | + node.args[args_idx.idx] |
| 119 | + if args_idx.inner_idx is None |
121 | 120 | # pyre-ignore[16]: no attribute |
122 | | - else node.args[idx[0]][idx[1]] |
| 121 | + else node.args[args_idx.idx][args_idx.inner_idx] |
123 | 122 | ) |
124 | 123 | annotation.input_qspec_map[arg] = ( |
125 | 124 | custom_spec[0] if custom_spec else spec |
126 | 125 | ) |
127 | 126 | # pyre-ignore[16]: no attribute |
128 | 127 | node.meta[Q_ANNOTATION_KEY] = annotation |
129 | 128 |
130 | | - def annotate_weights_or_biases( |
131 | | - weights_or_biases: List[Tuple[fx.Node, int]], |
132 | | - spec: Optional[QuantizationSpec], |
133 | | - ) -> None: |
134 | | - for node, idx, *custom_spec in weights_or_biases: |
135 | | - annotation = node.meta.get( |
136 | | - Q_ANNOTATION_KEY, |
137 | | - QuantizationAnnotation(_annotated=True), |
138 | | - ) |
139 | | - annotation.input_qspec_map[node.args[idx]] = ( |
140 | | - custom_spec[0] if custom_spec else spec |
141 | | - ) |
142 | | - node.meta[Q_ANNOTATION_KEY] = annotation |
143 | | - |
144 | 129 | # pyre-ignore[6]: incompatible parameter type |
145 | 130 | annotate_inputs(anchors.inputs, input_act_qspec) |
146 | | - annotate_weights_or_biases(anchors.weights, weight_qspec) |
| 131 | + annotate_inputs(anchors.weights, weight_qspec) |
147 | 132 | # pyre-ignore[6]: incompatible parameter type |
148 | | - annotate_weights_or_biases(anchors.biases, bias_qspec) |
| 133 | + annotate_inputs(anchors.biases, bias_qspec) |
149 | 134 | return model |
150 | 135 |
151 | 136 | def validate(self, model: fx.GraphModule) -> None: |
152 | 137 | pass |
153 | 138 |
154 | 139 | @classmethod |
155 | | - def get_supported_operators(cls) -> List[OperatorConfig]: |
| 140 | + def get_supported_operators(cls) -> list[OperatorConfig]: |
156 | 141 | return [] |
157 | 142 |
158 | 143 |
@@ -191,12 +176,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]: |
191 | 176 |
192 | 177 | class NeutronQuantizer(ComposableQuantizer): |
193 | 178 | def __init__(self): |
194 | | - static_qconfig = QuantizationConfig( |
195 | | - act_qspec, |
196 | | - act_qspec, |
197 | | - wgt_qspec, |
198 | | - None, |
199 | | - ) |
| 179 | + static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None) |
200 | 180 | static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None) |
201 | 181 | super().__init__( |
202 | 182 | [ |
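Context for the refactor above: the new NodeArgsIdx import replaces the old convention where an input anchor's index was either a plain int or an (outer, inner) tuple, and it lets the former annotate_weights_or_biases helper be folded into annotate_inputs, since the pattern anchors for weights and biases now appear to use the same (node, NodeArgsIdx, optional DerivedQuantizationSpec) tuples as activations. Dropping the typing import follows from the switch to built-in generics (list, tuple; PEP 585) and | unions (PEP 604) in the new annotations.

The definition of NodeArgsIdx is not part of this diff; the sketch below is a minimal reconstruction consistent with the usage above. Only the attribute names idx and inner_idx are confirmed by the hunk; the dataclass shape and the default value are assumptions.

    from dataclasses import dataclass

    import torch.fx as fx

    @dataclass
    class NodeArgsIdx:
        idx: int                      # position within node.args
        inner_idx: int | None = None  # assumed default; position inside
                                      # node.args[idx] when that arg is a
                                      # list/tuple

    def resolve_arg(node: fx.Node, args_idx: NodeArgsIdx):
        # Mirrors the branch in annotate_inputs: a direct argument when
        # inner_idx is None, otherwise an element nested one level deep
        # (e.g. one entry of a cat node's input list).
        if args_idx.inner_idx is None:
            return node.args[args_idx.idx]
        return node.args[args_idx.idx][args_idx.inner_idx]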