Commit 5f9da89

metascroy authored and facebook-github-bot committed
Migrate ExecuTorch's use of pt2e from torch.ao to torchao (#10294)
Summary: Most code related to PT2E quantization is migrating from torch.ao.quantization to torchao.quantization.pt2e. torchao.quantization.pt2e contains an exact copy of the PT2E code in torch.ao.quantization. The torchao pin in ExecuTorch has already been bumped to pick up these changes.

Pull Request resolved: #10294

Reviewed By: SS-JIA

Differential Revision: D74694311

Pulled By: metascroy
1 parent 9aaea31 commit 5f9da89
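
The change is mechanical: the same PT2E symbols are simply imported from their torchao copies. A minimal sketch of the rename, with paths taken from the diffs below:

# Before this commit (now rejected by the TORCH_AO_IMPORT lint rule added below):
# from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
# from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
# from torch.ao.quantization.observer import HistogramObserver

# After this commit, the same objects come from the torchao copy of PT2E:
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer
from torchao.quantization.pt2e.observer import HistogramObserver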

File tree: 81 files changed, +324 -267 lines changed


.lintrunner.toml

Lines changed: 28 additions & 0 deletions
@@ -378,3 +378,31 @@ command = [
     '--',
     '@{{PATHSFILE}}',
 ]
+
+[[linter]]
+code = "TORCH_AO_IMPORT"
+include_patterns = ["**/*.py"]
+exclude_patterns = [
+    "third-party/**",
+]
+
+command = [
+    "python3",
+    "-m",
+    "lintrunner_adapters",
+    "run",
+    "grep_linter",
+    "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b",
+    "--linter-name=TorchAOImport",
+    "--error-name=Prohibited torch.ao.quantization import",
+    """--error-description=\
+    Imports from torch.ao.quantization are not allowed. \
+    Please import from torchao.quantization.pt2e instead.\n \
+    * torchao.quantization.pt2e (includes all the utils, including observers, fake quants etc.) \n \
+    * torchao.quantization.pt2e.quantizer (quantizer related objects and utils) \n \
+    * torchao.quantization.pt2e.quantize_pt2e (prepare_pt2e, prepare_qat_pt2e, convert_pt2e) \n\n \
+    If you need something from torch.ao.quantization, you can add your file to an exclude_patterns for TORCH_AO_IMPORT in .lintrunner.toml. \
+    """,
+    "--",
+    "@{{PATHSFILE}}",
+]
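
As a quick sanity check of the pattern above, here is a small self-contained sketch that mirrors the regex with Python's re module; it only approximates grep_linter, which applies the pattern line by line to the linted files:

import re

# Mirror of the TORCH_AO_IMPORT grep pattern (unescaped from the TOML above).
PATTERN = re.compile(
    r"\bfrom torch\.ao\.quantization\.(?:quantizer|observer|quantize_pt2e|pt2e)"
    r"(?:\.[A-Za-z0-9_]+)*\b"
)

flagged = "from torch.ao.quantization.quantizer import QuantizationSpec"
allowed = "from torchao.quantization.pt2e.quantizer import QuantizationSpec"

assert PATTERN.search(flagged) is not None  # old import path is reported
assert PATTERN.search(allowed) is None      # migrated import passes the linter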

.mypy.ini

Lines changed: 3 additions & 0 deletions
@@ -97,3 +97,6 @@ ignore_missing_imports = True
 
 [mypy-zstd]
 ignore_missing_imports = True
+
+[mypy-torchao.*]
+follow_untyped_imports = True
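
The override makes mypy follow imports into torchao, which ships without type stubs; a minimal illustration, assuming a mypy version recent enough to support follow_untyped_imports:

# With the [mypy-torchao.*] override above, mypy analyzes this import instead of
# flagging a missing-stubs error (illustrative; exact diagnostics depend on the
# mypy version in use).
from torchao.quantization.pt2e import ObserverOrFakeQuantize

obs: ObserverOrFakeQuantize  # the annotation is checked against the followed module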

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -15,12 +15,12 @@
 )
 
 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch.ao.quantization.quantize_pt2e import (
+from torch.export import export_for_training
+from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
     prepare_qat_pt2e,
 )
-from torch.export import export_for_training
 
 
 class TestCoreMLQuantizer:
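
A minimal sketch of the post-migration PT2E flow these tests exercise, with the quantizer passed in rather than configured here (how the CoreMLQuantizer is configured is omitted):

import torch
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_module(model: torch.nn.Module, example_inputs, quantizer: CoreMLQuantizer):
    # Capture the module with torch.export, then run PT2E prepare/convert using
    # the torchao copies of the quantization entry points.
    exported = export_for_training(model.eval(), example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass
    return convert_pt2e(prepared)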

backends/arm/quantizer/arm_quantizer.py

Lines changed: 11 additions & 12 deletions
@@ -30,25 +30,24 @@
     is_vgf,
 ) # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from torch.ao.quantization.fake_quantize import (
+from torch.fx import GraphModule, Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
+    ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpec,
+    Quantizer,
 )
-from torch.fx import GraphModule, Node
 
 __all__ = [
     "TOSAQuantizer",
@@ -97,7 +96,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:
@@ -337,14 +336,14 @@ def _annotate_io(
             if is_annotated(node):
                 continue
             if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(
+                annotate_output_qspec(
                     node,
                     quantization_config.get_output_act_qspec(),
                 )
                 mark_node_as_annotated(node)
             if node.op == "output":
                 parent = node.all_input_nodes[0]
-                _annotate_input_qspec_map(
+                annotate_input_qspec_map(
                     node, parent, quantization_config.get_input_act_qspec()
                 )
                 mark_node_as_annotated(node)
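
For context, a minimal sketch (not the ExecuTorch code itself) of the renamed torchao symbols used in this file: the observer-constructor alias loses its leading underscore, and QuantizationSpec now comes from torchao.quantization.pt2e.quantizer:

import torch
from torchao.quantization.pt2e import MinMaxObserver, ObserverOrFakeQuantizeConstructor
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Same pattern as get_symmetric_quantization_config above, reduced to one spec.
weight_observer_ctr: ObserverOrFakeQuantizeConstructor = MinMaxObserver
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=weight_observer_ctr,
)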

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 2 additions & 2 deletions
@@ -15,10 +15,10 @@
 
 import torch
 from torch._subclasses import FakeTensor
-
-from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import GraphModule, Node
 
+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+
 
 def is_annotated(node: Node) -> bool:
     """Given a node return whether the node is annotated."""

backends/arm/quantizer/quantization_annotator.py

Lines changed: 8 additions & 7 deletions
@@ -12,12 +12,13 @@
 import torch.fx
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
-from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
 from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
 
 from .arm_quantizer_utils import (
     is_annotated,
@@ -118,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
         strict=True,
     ):
         assert isinstance(n_arg, Node)
-        _annotate_input_qspec_map(node, n_arg, qspec)
+        annotate_input_qspec_map(node, n_arg, qspec)
         if quant_property.mark_annotated:
             mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]
 
@@ -129,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
     assert not quant_property.optional
     assert quant_property.index == 0, "Only one output annotation supported currently"
 
-    _annotate_output_qspec(node, quant_property.qspec)
+    annotate_output_qspec(node, quant_property.qspec)
 
 
 def _match_pattern(

backends/arm/quantizer/quantization_config.py

Lines changed: 2 additions & 2 deletions
@@ -9,9 +9,9 @@
 from dataclasses import dataclass
 
 import torch
-from torch.ao.quantization import ObserverOrFakeQuantize
+from torchao.quantization.pt2e import ObserverOrFakeQuantize
 
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,

backends/arm/test/ops/test_add.py

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
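
A rough sketch of how a quantization spec built from these migrated imports looks in the Arm tests; the dtype, range, and eps values here are illustrative, not the exact test values:

import torch
from torchao.quantization.pt2e.observer import HistogramObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Illustrative 16-bit activation spec using the torchao copies of the PT2E classes.
act_qspec = QuantizationSpec(
    dtype=torch.int16,
    quant_min=-32768,
    quant_max=32767,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-16),
)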

backends/arm/test/ops/test_sigmoid_16bit.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 
 def _get_16_bit_quant_config():

backends/arm/test/ops/test_sigmoid_32bit.py

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 
 def _get_16_bit_quant_config():
