@@ -10,18 +10,18 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor
 
-from torch.ao.quantization.quantizer import QuantizationAnnotation
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-
 from torch.export import export_for_training
 from torch.fx import Graph, Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
 )
 
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec as _annotate_output_qspec,
+    QuantizationAnnotation,
+)
+
 from .qconfig import QuantizationConfig
 
 
@@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern(
             torch.ops.aten.linear.default,
         ]:
             weight_node = producer_node.args[1]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 producer_node,
                 weight_node,
                 quant_config.weight,
@@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
         return
 
     weight_node = node.args[1]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quant_config.weight,
@@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
         return
 
     wgt_node = node.args[0]
-    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
+    annotate_input_qspec_map(node, wgt_node, quant_config.activation)
     _mark_as_annotated([node])
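
For reference (not part of this change), a minimal sketch of how the renamed helpers are called after the migration. `annotate_linear` is a hypothetical wrapper, and reusing `quant_config.activation` as the output qspec is an assumption rather than something this diff does:

from torch.fx import Node

from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_linear(node: Node, quant_config) -> None:
    # Hypothetical example wrapper; `quant_config` is assumed to be the
    # QuantizationConfig from .qconfig, with `weight` and `activation`
    # qspecs as used in the hunks above.
    weight_node = node.args[1]
    # Record the weight qspec for the (node, weight_node) input edge.
    annotate_input_qspec_map(node, weight_node, quant_config.weight)
    # Record a qspec for the node's output; using the activation qspec
    # here is an assumption, not taken from this diff.
    annotate_output_qspec(node, quant_config.activation)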