|
30 | 30 | is_vgf, |
31 | 31 | ) # usort: skip |
32 | 32 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
33 | | -from torch.ao.quantization.fake_quantize import ( |
| 33 | +from torch.fx import GraphModule, Node |
| 34 | +from torchao.quantization.pt2e import ( |
34 | 35 | FakeQuantize, |
35 | 36 | FusedMovingAvgObsFakeQuantize, |
36 | | -) |
37 | | -from torch.ao.quantization.observer import ( |
38 | 37 | HistogramObserver, |
39 | 38 | MinMaxObserver, |
40 | 39 | MovingAverageMinMaxObserver, |
41 | 40 | MovingAveragePerChannelMinMaxObserver, |
| 41 | + ObserverOrFakeQuantizeConstructor, |
42 | 42 | PerChannelMinMaxObserver, |
43 | 43 | PlaceholderObserver, |
44 | 44 | ) |
45 | | -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor |
46 | | -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer |
47 | | -from torch.ao.quantization.quantizer.utils import ( |
48 | | - _annotate_input_qspec_map, |
49 | | - _annotate_output_qspec, |
| 45 | +from torchao.quantization.pt2e.quantizer import ( |
| 46 | + annotate_input_qspec_map, |
| 47 | + annotate_output_qspec, |
| 48 | + QuantizationSpec, |
| 49 | + Quantizer, |
50 | 50 | ) |
51 | | -from torch.fx import GraphModule, Node |
52 | 51 |
|
53 | 52 | __all__ = [ |
54 | 53 | "TOSAQuantizer", |
@@ -97,7 +96,7 @@ def get_symmetric_quantization_config( |
97 | 96 | weight_qscheme = ( |
98 | 97 | torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric |
99 | 98 | ) |
100 | | - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( |
| 99 | + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( |
101 | 100 | MinMaxObserver |
102 | 101 | ) |
103 | 102 | if is_qat: |
@@ -337,14 +336,14 @@ def _annotate_io( |
337 | 336 | if is_annotated(node): |
338 | 337 | continue |
339 | 338 | if node.op == "placeholder" and len(node.users) > 0: |
340 | | - _annotate_output_qspec( |
| 339 | + annotate_output_qspec( |
341 | 340 | node, |
342 | 341 | quantization_config.get_output_act_qspec(), |
343 | 342 | ) |
344 | 343 | mark_node_as_annotated(node) |
345 | 344 | if node.op == "output": |
346 | 345 | parent = node.all_input_nodes[0] |
347 | | - _annotate_input_qspec_map( |
| 346 | + annotate_input_qspec_map( |
348 | 347 | node, parent, quantization_config.get_input_act_qspec() |
349 | 348 | ) |
350 | 349 | mark_node_as_annotated(node) |
|
0 commit comments