
Commit f5578a2

Migrate pt2e backend batch1
Differential Revision: D75119434
Pull Request resolved: #11033
1 parent: d892b66

File tree

14 files changed: +50, -30 lines

.lintrunner.toml

Lines changed: 18 additions & 1 deletion
@@ -384,6 +384,23 @@ code = "TORCH_AO_IMPORT"
 include_patterns = ["**/*.py"]
 exclude_patterns = [
     "third-party/**",
+    # TODO: remove exceptions as we migrate
+    # backends
+    "backends/arm/quantizer/**",
+    "backends/arm/test/ops/**",
+    "backends/vulkan/quantizer/**",
+    "backends/vulkan/test/**",
+    "backends/cadence/aot/quantizer/**",
+    "backends/qualcomm/quantizer/**",
+    "examples/qualcomm/**",
+    "backends/xnnpack/quantizer/**",
+    "backends/xnnpack/test/**",
+    "exir/tests/test_passes.py",
+    "extension/llm/export/builder.py",
+    "extension/llm/export/quantizer_lib.py",
+    "exir/tests/test_memory_planning.py",
+    "backends/transforms/duplicate_dynamic_quant_chain.py",
+    "exir/backend/test/demos/test_xnnpack_qnnpack.py",
 ]

 command = [
@@ -392,7 +409,7 @@ command = [
     "lintrunner_adapters",
     "run",
     "grep_linter",
-    "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantize_pt2e)(?:\\.[A-Za-z0-9_]+)*\\b",
+    "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b",
     "--linter-name=TorchAOImport",
     "--error-name=Prohibited torch.ao.quantization import",
     """--error-description=\

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 5 additions & 5 deletions
@@ -16,15 +16,15 @@
     ReplaceQuantNodesPass,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer.quantizer import (
+from torch.export import export, export_for_training
+from torch.fx import GraphModule
+from torchao.quantization.pt2e.observer import HistogramObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
     Quantizer,
 )
-from torch.export import export, export_for_training
-from torch.fx import GraphModule
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


 @dataclass(eq=True, frozen=True)
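
Note: with the imports reordered above, this test exercises the standard pt2e flow entirely from the torchao namespace. A minimal sketch of that flow, assuming a toy module and some Quantizer implementation (the quantizer argument is a stand-in, e.g. the test's own quantizer):

import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_model(model: torch.nn.Module, example_inputs: tuple, quantizer):
    # Capture the eager module into an FX graph.
    captured = export_for_training(model, example_inputs).module()
    # Insert observers where the quantizer's annotate() asked for them.
    prepared = prepare_pt2e(captured, quantizer)
    prepared(*example_inputs)  # run once (or more) to calibrate the observers
    # Swap observers for quantize/dequantize nodes.
    return convert_pt2e(prepared)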

backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dim_order_utils import get_dim_order
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torchao.quantization.pt2e import find_sequential_partitions


 class PermuteMemoryFormatsPass(ExportPass):

backends/example/example_operators/utils.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
+from torchao.quantization.pt2e.quantizer import QuantizationAnnotation


 def _nodes_are_annotated(node_list):
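
Note: QuantizationAnnotation itself is unchanged; it still rides on FX node metadata. A sketch of the check a helper like _nodes_are_annotated boils down to (not the exact ExecuTorch implementation):

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation

def node_is_annotated(node: Node) -> bool:
    # Quantizers stash their annotation under this node.meta key.
    annotation = node.meta.get("quantization_annotation")
    return isinstance(annotation, QuantizationAnnotation) and annotation._annotated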

backends/example/example_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch.export import ExportedProgram
 from torch.fx.passes.operator_support import OperatorSupportBase
-from torchao.quantization.pt2e.graph_utils import find_sequential_partitions
+from torchao.quantization.pt2e import find_sequential_partitions


 @final
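
Note: this file and permute_memory_formats_pass.py above make the same change: find_sequential_partitions is now imported from the torchao.quantization.pt2e package root rather than a graph_utils submodule. A sketch of a typical call, on a toy model:

import torch
from torch.export import export_for_training
from torchao.quantization.pt2e import find_sequential_partitions

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
gm = export_for_training(model, (torch.randn(1, 3, 16, 16),)).module()

# Returns a list of [conv_partition, relu_partition] pairs that occur
# back-to-back in the captured graph.
fused = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])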

backends/mediatek/quantizer/annotator.py

Lines changed: 9 additions & 9 deletions
@@ -10,18 +10,18 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor

-from torch.ao.quantization.quantizer import QuantizationAnnotation
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-
 from torch.export import export_for_training
 from torch.fx import Graph, Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
 )

+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec as _annotate_output_qspec,
+    QuantizationAnnotation,
+)
+
 from .qconfig import QuantizationConfig


@@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern(
         torch.ops.aten.linear.default,
     ]:
         weight_node = producer_node.args[1]
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             producer_node,
             weight_node,
             quant_config.weight,
@@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
         return

     weight_node = node.args[1]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quant_config.weight,
@@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
         return

     wgt_node = node.args[0]
-    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
+    annotate_input_qspec_map(node, wgt_node, quant_config.activation)
     _mark_as_annotated([node])
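
Note: the torchao helpers drop their underscore prefix; annotate_output_qspec is aliased back to _annotate_output_qspec in the import block above so its call sites stay untouched, while the annotate_input_qspec_map call sites are renamed in place. Behavior is unchanged; a sketch of how the pair is used on an aten.linear node (node and quant_config are stand-ins):

from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)

def annotate_linear(node, quant_config):
    # As in the hunks above: args[1] of aten.linear is the weight.
    weight_node = node.args[1]
    # Observe this input edge with the weight qspec...
    annotate_input_qspec_map(node, weight_node, quant_config.weight)
    # ...and the node's output with the activation qspec.
    annotate_output_qspec(node, quant_config.activation)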

backends/mediatek/quantizer/qconfig.py

Lines changed: 3 additions & 3 deletions
@@ -10,9 +10,9 @@

 import torch

-from torch.ao.quantization.fake_quantize import FakeQuantize
-from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
+from torchao.quantization.pt2e.fake_quantize import FakeQuantize
+from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver
+from torchao.quantization.pt2e.quantizer import QuantizationSpec


 @unique
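
Note: FakeQuantize, the observers, and QuantizationSpec keep their torch.ao interfaces under the new namespace, so qconfig definitions only need the import swap. A sketch of a spec built from these imports (the numeric values are illustrative, not MediaTek's actual config):

import torch
from torchao.quantization.pt2e.observer import MinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Symmetric per-tensor int8 activations, calibrated by MinMaxObserver.
act_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_tensor_symmetric,
    observer_or_fake_quant_ctr=MinMaxObserver,
)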

backends/mediatek/quantizer/quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 # except in compliance with the License. See the license file in the root
 # directory of this source tree for more details.

-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
+from torchao.quantization.pt2e.quantizer import Quantizer

 from .._passes.decompose_scaled_dot_product_attention import (
     DecomposeScaledDotProductAttention,
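
Note: only the base-class import moves; the Quantizer interface is unchanged. The skeleton any such quantizer implements (a sketch; the real annotation logic for this backend lives in annotator.py):

from torch.fx import GraphModule
from torchao.quantization.pt2e.quantizer import Quantizer

class SketchQuantizer(Quantizer):
    def annotate(self, model: GraphModule) -> GraphModule:
        # Walk model.graph and attach QuantizationAnnotation metadata
        # to the nodes that should be observed/quantized.
        return model

    def validate(self, model: GraphModule) -> None:
        # Optionally reject graphs this backend cannot quantize.
        pass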

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 3 additions & 3 deletions
@@ -35,9 +35,9 @@
     QuantizationSpec,
 )
 from torch import fx
-from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer
+from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer


 class NeutronAtenQuantizer(Quantizer):
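
Note: ComposableQuantizer also moves verbatim; it fans one annotate() call out across several quantizers over the same captured graph. A sketch (the two argument quantizers are hypothetical stand-ins):

from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.pt2e.quantizer.composable_quantizer import (
    ComposableQuantizer,
)

def compose(conv_quantizer: Quantizer, linear_quantizer: Quantizer) -> Quantizer:
    # Quantizers annotate in list order; later ones must not re-annotate
    # nodes already claimed by earlier ones.
    return ComposableQuantizer([conv_quantizer, linear_quantizer])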

backends/nxp/quantizer/patterns.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from executorch.backends.nxp.quantizer.utils import get_bias_qparams
 from torch import fx
 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     SharedQuantizationSpec,
