2 changes: 0 additions & 2 deletions .lintrunner.toml
@@ -391,8 +391,6 @@ exclude_patterns = [
"backends/vulkan/quantizer/**",
"backends/vulkan/test/**",
"backends/cadence/aot/quantizer/**",
"backends/qualcomm/quantizer/**",
"examples/qualcomm/**",
"backends/xnnpack/quantizer/**",
"backends/xnnpack/test/**",
"exir/tests/test_passes.py",
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -131,8 +131,8 @@ def get_to_edge_transform_passes(
from executorch.backends.qualcomm._passes import utils
from executorch.exir.dialects._ops import ops as exir_ops

utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)

passes_job = (
passes_job if passes_job is not None else get_capture_program_passes()
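The affine quantize/dequantize overloads now live under the `torchao` edge namespace rather than `pt2e_quant`. A minimal sketch of how the updated `q_ops`/`dq_ops` sets can be consulted in a pass, assuming they remain plain sets of edge-dialect op overloads as the `.add()` calls above imply (`is_affine_dq` is a hypothetical helper, not part of this PR):

```python
import torch
from executorch.backends.qualcomm._passes import utils


def is_affine_dq(node: torch.fx.Node) -> bool:
    # Matches edge-dialect dequantize nodes, including the relocated
    # exir_ops.edge.torchao.dequantize_affine overload registered above.
    return node.op == "call_function" and node.target in utils.dq_ops
```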
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/node_visitor.py
@@ -254,8 +254,8 @@ def get_quant_encoding_conf(
)
# TODO: refactor this when target could be correctly detected
per_block_encoding = {
exir_ops.edge.pt2e_quant.quantize_affine.default,
exir_ops.edge.pt2e_quant.dequantize_affine.default,
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
}
if quant_attrs[QCOM_ENCODING] in per_block_encoding:
return self.make_qnn_per_block_config(node, quant_attrs)
4 changes: 2 additions & 2 deletions backends/qualcomm/partition/utils.py
@@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.upsample_bicubic2d.vec,
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
torch.ops.aten.unbind.int,
torch.ops.pt2e_quant.quantize_affine.default,
torch.ops.pt2e_quant.dequantize_affine.default,
torch.ops.torchao.quantize_affine.default,
torch.ops.torchao.dequantize_affine.default,
]
return do_not_decompose
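Keeping `torch.ops.torchao.quantize_affine` / `dequantize_affine` in the skip list prevents them from being decomposed during lowering. A rough sketch of the intended effect, assuming the caller simply drops these ops from whatever decomposition table it builds (the actual wiring lives in `exir/program/_program.py`; `full_decomp_table` and `exported_program` are hypothetical placeholders):

```python
# Remove the skipped ops from the decomposition table so the torchao affine
# q/dq ops survive into the edge program instead of being decomposed.
skip_ops = set(get_skip_decomp_table())
table = {op: fn for op, fn in full_decomp_table.items() if op not in skip_ops}
edge_program = exported_program.run_decompositions(table)
```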
55 changes: 26 additions & 29 deletions backends/qualcomm/quantizer/annotators.py
@@ -12,20 +12,17 @@
from torch._ops import OpOverload

from torch._subclasses import FakeTensor
from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.fx import Node

from torch.ao.quantization.observer import FixedQParamsObserver
from torch.ao.quantization.quantizer import (
from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
DerivedQuantizationSpec,
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.ao.quantization.quantizer.utils import (
_annotate_input_qspec_map,
_annotate_output_qspec,
)
from torch.fx import Node

from .qconfig import (
get_16a16w_qnn_ptq_config,
@@ -618,19 +615,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None
return

# TODO current only support 16a16w
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.input_activation,
)
nodes_to_mark_annotated = [node]
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)
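The annotation helpers drop their leading underscores as they move from `torch.ao.quantization.quantizer.utils` to `torchao.quantization.pt2e.quantizer`. A minimal annotator sketch using the new names, following the same shape as `annotate_rms_norm` above (the single-input op and config fields are illustrative; `_is_annotated` and `_mark_nodes_as_annotated` are the module-level helpers already used throughout this file):

```python
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
)


def annotate_single_input_op(node: Node, quantization_config) -> None:
    if _is_annotated([node]):
        return
    act_node = node.args[0]
    # Input and output quantization specs come straight from the config,
    # exactly as in the annotators above.
    annotate_input_qspec_map(node, act_node, quantization_config.input_activation)
    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node])
```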


@@ -819,25 +816,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None
if _is_annotated([node]):
return

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.weight,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -1002,12 +999,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
if _is_annotated([node]):
return

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
quantization_config.input_activation,
)
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
quantization_config.weight,
@@ -1018,9 +1015,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
bias_config = quantization_config.bias(node)
else:
bias_config = quantization_config.bias
_annotate_input_qspec_map(node, bias_node, bias_config)
annotate_input_qspec_map(node, bias_node, bias_config)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)

# We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.
@@ -1038,29 +1035,29 @@ def annotate_batch_and_instance_norm(
return

annotated_args = [act]
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act,
quantization_config.input_activation,
)
# QNN requires uint8 instead of int8 in 'weight' config
if weight is not None:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight,
quantization_config.input_activation,
)
annotated_args.append(weight)

if bias is not None:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias,
quantization_config.bias,
)
annotated_args.append(bias)

_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated([node, *annotated_args])


@@ -1070,7 +1067,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None
return

if _is_float_tensor(node):
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated([node])


@@ -1086,32 +1083,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None
return
input_act_qspec = quantization_config.input_activation

_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
act_node,
input_act_qspec,
)
if input_act_qspec.dtype == torch.int32:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
get_16a16w_qnn_ptq_config().weight,
)
else:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
weight_node,
input_act_qspec,
)
nodes_to_mark_annotated = [node, weight_node]
if bias_node:
_annotate_input_qspec_map(
annotate_input_qspec_map(
node,
bias_node,
quantization_config.bias,
)
nodes_to_mark_annotated.append(bias_node)
_annotate_output_qspec(node, quantization_config.output_activation)
annotate_output_qspec(node, quantization_config.output_activation)
_mark_nodes_as_annotated(nodes_to_mark_annotated)


6 changes: 3 additions & 3 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -17,13 +17,13 @@
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
from torch.ao.quantization.quantizer import (
from torch.fx import Node
from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_mimi_decoder(gm: torch.fx.GraphModule):
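Custom annotations in `custom_annotation.py` compose the same way with the relocated classes. A sketch of a fixed-qparams 16-bit activation spec built from the new import paths (the scale, zero-point, and dtype values are illustrative, not taken from this PR):

```python
import torch
from torchao.quantization.pt2e import FixedQParamsObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# 16-bit asymmetric activation spec with fixed quantization parameters.
act_qspec = QuantizationSpec(
    dtype=torch.uint16,
    quant_min=0,
    quant_max=65535,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=FixedQParamsObserver.with_args(
        scale=1.0 / 65535.0,
        zero_point=0,
        dtype=torch.uint16,
        qscheme=torch.per_tensor_affine,
        quant_min=0,
        quant_max=65535,
    ),
)
```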
Original file line number Diff line number Diff line change
@@ -7,8 +7,8 @@
from typing import Tuple

import torch
from torch.ao.quantization.observer import MappingType, PerBlock
from torch.ao.quantization.pt2e._affine_quantization import (
from torchao.quantization.pt2e import MappingType, PerBlock
from torchao.quantization.pt2e._affine_quantization import (
_get_reduction_params,
AffineQuantizedMinMaxObserver,
choose_qparams_affine_with_min_max,
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase
from torchao.quantization.pt2e import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
11 changes: 6 additions & 5 deletions backends/qualcomm/quantizer/qconfig.py
@@ -7,18 +7,19 @@
PerBlockParamObserver,
)
from torch import Tensor
from torch.ao.quantization.fake_quantize import (
from torch.fx import Node
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
)
from torch.ao.quantization.observer import (
MinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
)
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
QuantizationSpec,
)


@dataclass(eq=True)
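The QAT-side constructors relocate as well. For example, a fake-quantize constructor in the style these configs use, now built from the `torchao.quantization.pt2e` imports (parameter values are illustrative; `MinMaxObserver` is importable from the same module, as the llama example below shows):

```python
import torch
from torchao.quantization.pt2e import FakeQuantize, MinMaxObserver

# The observer class and its kwargs are forwarded via with_args, so the
# resulting ctr can be dropped into a QuantizationSpec as-is.
act_fake_quant_ctr = FakeQuantize.with_args(
    observer=MinMaxObserver,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
)
```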
7 changes: 3 additions & 4 deletions backends/qualcomm/quantizer/quantizer.py
@@ -12,8 +12,9 @@
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

from torch._ops import OpOverload
from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule
from torchao.quantization.pt2e import UniformQuantizationObserverBase
from torchao.quantization.pt2e.quantizer import Quantizer

from .annotators import OP_ANNOTATOR

@@ -130,9 +131,7 @@ class ModuleQConfig:
is_qat: bool = False
is_conv_per_channel: bool = False
is_linear_per_channel: bool = False
act_observer: Optional[
torch.ao.quantization.observer.UniformQuantizationObserverBase
] = None
act_observer: Optional[UniformQuantizationObserverBase] = None

def __post_init__(self):
if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
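With the quantizer now deriving from `torchao.quantization.pt2e.quantizer.Quantizer`, the end-to-end PTQ flow keeps the same shape, just with torchao's `prepare_pt2e`/`convert_pt2e` (import path as used in `examples/qualcomm/oss_scripts/llama/llama.py` below; the module and inputs here are placeholders):

```python
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = MyModule().eval()                        # placeholder eager module
example_inputs = (torch.randn(1, 3, 224, 224),)  # placeholder inputs

# Export, annotate with the QNN quantizer, calibrate, then convert to Q/DQ form.
graph_module = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(graph_module, QnnQuantizer())
prepared(*example_inputs)                        # calibration pass
quantized = convert_pt2e(prepared)
```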
7 changes: 4 additions & 3 deletions backends/qualcomm/tests/utils.py
@@ -14,6 +14,7 @@

import numpy as np
import torch
import torchao
from executorch import exir
from executorch.backends.qualcomm._passes.utils import dq_ops
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
@@ -537,8 +538,8 @@ def get_qdq_module(
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.pt2e_quant.quantize_affine.default,
torch.ops.pt2e_quant.dequantize_affine.default,
torch.ops.torchao.quantize_affine.default,
torch.ops.torchao.dequantize_affine.default,
}
if not bypass_check:
self.assertTrue(nodes.intersection(q_and_dq))
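The new top-level `import torchao` in this file also matters for the `torch.ops.torchao.*` overloads listed in `q_and_dq`: those only resolve once torchao's custom ops have been registered, which (as is typical for custom-op libraries) is assumed here to happen on import. A quick sanity check under that assumption:

```python
import torch
import torchao  # registration of torchao custom ops is assumed to happen on import

# Both affine overloads referenced in q_and_dq should now be resolvable.
assert hasattr(torch.ops.torchao, "quantize_affine")
assert hasattr(torch.ops.torchao, "dequantize_affine")
```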
@@ -569,7 +570,7 @@ def get_prepared_qat_module(
quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

prepared = prepare_qat_pt2e(m, quantizer)
return torch.ao.quantization.move_exported_model_to_train(prepared)
return torchao.quantization.pt2e.move_exported_model_to_train(prepared)

def get_converted_sgd_trained_module(
self,
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/llama/llama.py
@@ -81,7 +81,7 @@
from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

from torch.ao.quantization.observer import MinMaxObserver
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
2 changes: 1 addition & 1 deletion examples/qualcomm/oss_scripts/moshi/mimi.py
@@ -37,7 +37,7 @@
from huggingface_hub import hf_hub_download
from moshi.models import loaders

from torch.ao.quantization.observer import MinMaxObserver
from torchao.quantization.pt2e import MinMaxObserver


def seed_all(seed):