|
30 | 30 | is_vgf, |
31 | 31 | ) # usort: skip |
32 | 32 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
33 | | -from torch.ao.quantization.fake_quantize import ( |
| 33 | + |
| 34 | +from torch.fx import GraphModule, Node |
| 35 | +from torchao.quantization.pt2e import ( |
34 | 36 | FakeQuantize, |
35 | 37 | FusedMovingAvgObsFakeQuantize, |
36 | | -) |
37 | | -from torch.ao.quantization.observer import ( |
38 | 38 | HistogramObserver, |
39 | 39 | MinMaxObserver, |
40 | 40 | MovingAverageMinMaxObserver, |
41 | 41 | MovingAveragePerChannelMinMaxObserver, |
| 42 | + ObserverOrFakeQuantizeConstructor, |
42 | 43 | PerChannelMinMaxObserver, |
43 | 44 | PlaceholderObserver, |
44 | 45 | ) |
45 | | -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor |
46 | | -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer |
47 | | -from torch.ao.quantization.quantizer.utils import ( |
48 | | - _annotate_input_qspec_map, |
49 | | - _annotate_output_qspec, |
| 46 | + |
| 47 | +from torchao.quantization.pt2e.quantizer import ( |
| 48 | + annotate_input_qspec_map, |
| 49 | + annotate_output_qspec, |
| 50 | + QuantizationSpec, |
| 51 | + Quantizer, |
50 | 52 | ) |
51 | | -from torch.fx import GraphModule, Node |
52 | 53 |
|
53 | 54 | __all__ = [ |
54 | 55 | "TOSAQuantizer", |
@@ -97,7 +98,7 @@ def get_symmetric_quantization_config( |
97 | 98 | weight_qscheme = ( |
98 | 99 | torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric |
99 | 100 | ) |
100 | | - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( |
| 101 | + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( |
101 | 102 | MinMaxObserver |
102 | 103 | ) |
103 | 104 | if is_qat: |
@@ -337,14 +338,14 @@ def _annotate_io( |
337 | 338 | if is_annotated(node): |
338 | 339 | continue |
339 | 340 | if node.op == "placeholder" and len(node.users) > 0: |
340 | | - _annotate_output_qspec( |
| 341 | + annotate_output_qspec( |
341 | 342 | node, |
342 | 343 | quantization_config.get_output_act_qspec(), |
343 | 344 | ) |
344 | 345 | mark_node_as_annotated(node) |
345 | 346 | if node.op == "output": |
346 | 347 | parent = node.all_input_nodes[0] |
347 | | - _annotate_input_qspec_map( |
| 348 | + annotate_input_qspec_map( |
348 | 349 | node, parent, quantization_config.get_input_act_qspec() |
349 | 350 | ) |
350 | 351 | mark_node_as_annotated(node) |
|
0 commit comments