4 changes: 4 additions & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -10,7 +10,11 @@
from executorch.exir.dialects._ops import ops as exir_ops

not_supported_operator = [
# output size is data dependent on the slice index
exir_ops.edge.aten._embedding_bag.default,
# for graph sharding purposes; different from the op used in decoder models
exir_ops.edge.dim_order_ops._clone_dim_order.default,
# QNN does not support 4-bit embedding
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]

56 changes: 34 additions & 22 deletions backends/qualcomm/quantizer/annotators.py
@@ -25,6 +25,8 @@
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

from .observers.concat_observer import ConcatObserver

from .qconfig import (
get_16a16w_qnn_ptq_config,
get_16a4w_qnn_qat_config,
@@ -691,7 +693,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:

@register_annotator([torch.ops.aten.slice.Tensor])
def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
annotate_single_in_single_out(node, quantization_config)
annotate_single_in_share_out(node, quantization_config)


@register_annotator([torch.ops.aten.slice_scatter.default])
@@ -1277,31 +1279,40 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->

@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
input_nodes = node.args[0]
if _is_annotated([node]) or not _is_float_tensor(node):
return

assert isinstance(input_nodes, Sequence)

first_input_node = input_nodes[0]
input_qspec_map = {}
assert isinstance(first_input_node, Node)
assert isinstance(node, Node)
if _is_float_tensor(first_input_node):
input_qspec_map[first_input_node] = quantization_config.input_activation
share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
(first_input_node, node)
)

for input_node in input_nodes[1:]:
if input_node not in input_qspec_map:
assert isinstance(input_node, Node)
if _is_float_tensor(input_node):
input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

input_qspec_map, input_nodes = {}, node.args[0]
for input in input_nodes:
input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
if (
# placeholder
input_qspec is None
or
# keep the shared qspec here for propagating the data range
# without introducing extra requantizations
not isinstance(input_qspec.output_qspec, SharedQuantizationSpec)
):
input_qspec_map[input] = quantization_config.input_activation

output_qspec = QuantizationSpec(
dtype=quantization_config.output_activation.dtype,
qscheme=quantization_config.output_activation.qscheme,
quant_max=quantization_config.output_activation.quant_max,
quant_min=quantization_config.output_activation.quant_min,
observer_or_fake_quant_ctr=ConcatObserver.with_args(
# we need to know the concat node in order to hack all the input observers' data range
# since deep copying the fake tensor (node.meta["val"]) is not allowed,
# we can currently only ship the graph & node name and perform post-processing inside the observer
**{
"node_name": node.name,
"graph": node.graph,
}
),
)
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=share_qparams_with_input_act0_qspec,
output_qspec=output_qspec,
_annotated=True,
)

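The requantization concern mentioned in the comments above can be seen with a small numeric example. The sketch below is illustrative only (it is not part of this diff and the helper name is made up): with independently observed ranges, each concat input gets its own scale, so the branches would need to be requantized before concatenation; with one unified range, a single scale covers every input and the output.

import torch

# Illustrative sketch: standard affine quantization parameters from a data range.
def affine_qparams(min_val, max_val, quant_min=0, quant_max=255):
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = quant_min - round(min_val / scale)
    return scale, zero_point

a = torch.tensor([-1.0, 0.5])   # range seen on concat input 0
b = torch.tensor([-4.0, 6.0])   # range seen on concat input 1

# independent ranges -> different scales, so a requantize is needed before cat
print(affine_qparams(a.min().item(), a.max().item()))   # scale ~0.0059
print(affine_qparams(b.min().item(), b.max().item()))   # scale ~0.0392

# unified range (what the ConcatObserver enforces) -> one shared scale
lo = min(a.min().item(), b.min().item())
hi = max(a.max().item(), b.max().item())
print(affine_qparams(lo, hi))                            # same scale for both inputs and the output
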
@@ -1345,6 +1356,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
input_act = node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = quantization_config.input_activation
share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node))

node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
@@ -1353,7 +1365,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:

for user in node.users:
user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
output_qspec=quantization_config.output_activation,
output_qspec=share_qparams_with_input_node_qspec,
_annotated=True,
)

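For reference, SharedQuantizationSpec can be built from an (input, node) edge so that a consumer reuses the qparams observed on that edge, which is how the chunk users above are tied back to the chunk input. The sketch below is standalone and illustrative; the import path is an assumption based on the pt2e quantizer types this file already uses.

import torch
from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec

class M(torch.nn.Module):
    def forward(self, x):
        return torch.chunk(x, 2, dim=-1)

gm = torch.fx.symbolic_trace(M())
chunk_node = next(n for n in gm.graph.nodes if n.op == "call_function")
input_node = chunk_node.args[0]

# consumers annotated with this spec share the qparams of the (input, chunk) edge
shared_spec = SharedQuantizationSpec((input_node, chunk_node))
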
74 changes: 74 additions & 0 deletions backends/qualcomm/quantizer/observers/concat_observer.py
@@ -0,0 +1,74 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchao.quantization.pt2e import UniformQuantizationObserverBase


class ConcatObserver(UniformQuantizationObserverBase):
"""
Fetch maximum data range of all tensors to be concatenated
"""

def __init__(
self,
node_name,
graph,
dtype=torch.uint8,
qscheme=torch.per_tensor_affine,
reduce_range=False,
quant_min=None,
quant_max=None,
factory_kwargs=None,
eps=torch.finfo(torch.float32).eps, # noqa: B008
is_dynamic=False,
**kwargs,
) -> None:
super().__init__(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
eps=eps,
is_dynamic=is_dynamic,
**kwargs,
)

factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
# get concat node and its inputs
self.concat_node = [node for node in graph.nodes if node.name == node_name][0]
self.input_nodes = self.concat_node.args[0]
self.input_observers = []

def forward(self, x_orig):
# calculate the min / max first
self.min_val = min(self.min_val, x_orig.min())
self.max_val = max(self.max_val, x_orig.max())

if len(self.input_observers) == 0:
# collect observers first if they are not cached
# we cannot do this in the constructor since the observers have not been created yet
for node in self.input_nodes:
obs_node = list(
filter(lambda user: user != self.concat_node, node.users.keys())
)[0]
self.input_observers.append(
getattr(obs_node.graph.owning_module, obs_node.name)
)

# update min / max for all observers of input nodes
for observers in self.input_observers:
observers.min_val = self.min_val
observers.max_val = self.max_val

return x_orig

def calculate_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
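
A self-contained sketch of the effect forward() aims for, using stand-in objects rather than the real observers: one running min/max is accumulated across calibration batches and written back onto every cached input observer, so all inputs of the concat end up quantized with the same range.

import torch

class StubObserver:
    # stand-in for the per-input observers; only holds a data range
    def __init__(self):
        self.min_val = torch.tensor(float("inf"))
        self.max_val = torch.tensor(float("-inf"))

input_observers = [StubObserver(), StubObserver()]
running_min = torch.tensor(float("inf"))
running_max = torch.tensor(float("-inf"))

for branch in (torch.randn(4), 5 * torch.randn(4)):      # tensors feeding the concat
    running_min = torch.minimum(running_min, branch.min())
    running_max = torch.maximum(running_max, branch.max())
    for obs in input_observers:                           # broadcast the unified range
        obs.min_val, obs.max_val = running_min, running_max

assert torch.equal(input_observers[0].min_val, input_observers[1].min_val)
assert torch.equal(input_observers[0].max_val, input_observers[1].max_val)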
2 changes: 1 addition & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
@@ -6043,7 +6043,7 @@ def test_llama_stories_110m(self):
# x86 does not allow weight sharing, so we don't check pte size
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 130_000_000) # 130MB
self.assertLessEqual(pte_size, 135_000_000) # 135MB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai

2 changes: 1 addition & 1 deletion backends/test/suite/operators/__init__.py
@@ -69,7 +69,7 @@ class TestCaseShim:
def __init__(self, test_runner):
self._test_runner = test_runner

def _test_op(self, model, args, flow, generate_random_test_inputs=True):
def _test_op(self, model, args, flow, generate_random_test_inputs=False):
self._test_runner.lower_and_run_model(
model, args, generate_random_test_inputs=generate_random_test_inputs
)
5 changes: 2 additions & 3 deletions examples/qualcomm/oss_scripts/fastvit.py
@@ -75,7 +75,7 @@ def main(args):
dtype=torch.uint8,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
**{"averaging_constant": 0.02}
**{"averaging_constant": 0.01}
),
)
weight_qspec = QuantizationSpec(
@@ -85,7 +85,7 @@
qscheme=torch.per_channel_symmetric,
ch_axis=0,
observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
**{"steps": 200, "use_mse": True}
**{"steps": 100, "use_mse": True}
),
)
# rewrite default per-channel ptq config
@@ -114,7 +114,6 @@ def main(args):
dataset=inputs,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
quant_dtype=QuantDtype.use_8a8w,
custom_quantizer=quantizer,
shared_buffer=args.shared_buffer,
)
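
For context on the observer settings used above: a moving-average min/max observer nudges its running range toward each new batch by the averaging_constant, so a smaller constant makes the calibrated range less sensitive to any single outlier batch. The update rule below is a sketch of that behavior (assumed to match MovingAverageMinMaxObserver's documented semantics, not taken from this diff):

import torch

def moving_average_update(min_val, max_val, x, averaging_constant=0.01):
    # first batch initializes the range; later batches move it by a small step
    if torch.isinf(min_val):
        return x.min(), x.max()
    new_min = min_val + averaging_constant * (x.min() - min_val)
    new_max = max_val + averaging_constant * (x.max() - max_val)
    return new_min, new_max

min_val = torch.tensor(float("inf"))
max_val = torch.tensor(float("-inf"))
for batch in (torch.randn(16), 10 * torch.randn(16)):    # calibration batches
    min_val, max_val = moving_average_update(min_val, max_val, batch)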