4 changes: 4 additions & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -10,7 +10,11 @@
 from executorch.exir.dialects._ops import ops as exir_ops

 not_supported_operator = [
+    # output size is data dependent on the slice index
+    exir_ops.edge.aten._embedding_bag.default,
+    # for graph sharding purposes; different from the op used in decoder models
+    exir_ops.edge.dim_order_ops._clone_dim_order.default,
     # QNN does not support 4-bit embedding
     exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]

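For context: ops on this list are rejected during partitioning, so they fall back to CPU rather than lowering to QNN. A minimal sketch of how such a deny list is typically consulted in an FX operator-support check (illustrative only, not the repo's actual partitioner code):

```python
from torch.fx.passes.operator_support import OperatorSupportBase


class DenyListSupport(OperatorSupportBase):
    """Illustrative checker: refuse any node whose target is on the deny list."""

    def is_node_supported(self, submodules, node) -> bool:
        # Edge-dialect ops show up as call_function targets in the FX graph.
        if node.op != "call_function":
            return False
        return node.target not in not_supported_operator
```
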
71 changes: 51 additions & 20 deletions backends/qualcomm/quantizer/annotators.py
@@ -691,7 +691,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:

 @register_annotator([torch.ops.aten.slice.Tensor])
 def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.slice_scatter.default])
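
`annotate_single_in_share_out` ties the slice output's encoding to its input, so a slice stops being a requantization point. A hedged sketch of what that helper plausibly does (mirroring the shared-spec pattern used elsewhere in this file; the helper's body is not shown in this diff):

```python
def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    # The output advertises the exact qparams of its single input.
    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={input_act: quantization_config.input_activation},
        output_qspec=SharedQuantizationSpec((input_act, node)),
        _annotated=True,
    )
```
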
@@ -1277,32 +1277,62 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->

 @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
 def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
-    input_nodes = node.args[0]
+    def _derived_quant_spec(
+        node: torch.fx.Node, output_qspec: QuantizationSpec
+    ) -> DerivedQuantizationSpec:
+        def _derive_concat_qparams_fn(
+            node, obs_or_fqs: List
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            # get the maximum scale
+            first_obj = obs_or_fqs[0]
+            first_obj.min_val = min(obj.min_val for obj in obs_or_fqs)
+            first_obj.max_val = max(obj.max_val for obj in obs_or_fqs)
+            scale, offset = first_obj.calculate_qparams()
+
+            # rewrite the input QDQ pairs with the maximum data range of all input tensors,
+            # since the topology order is constrained by the framework, which prevents us
+            # from creating dependencies like: [(node, input_0), ..., (node, input_n)]
+            for dq_node in node.args[0]:
+                # per-tensor args: (prev_node, scale, offset, q_min, q_max, dtype)
+                q_node = dq_node.args[0]
+                encoding = [scale.item(), offset.item(), *dq_node.args[3:]]
+                dq_node.args = (q_node, *encoding)
+                q_node.args = (q_node.args[0], *encoding)
+
+            return (scale, offset)
+
+        q_min = (
+            torch.iinfo(output_qspec.dtype).min
+            if output_qspec.quant_min is None
+            else output_qspec.quant_min
+        )
+        q_max = (
+            torch.iinfo(output_qspec.dtype).max
+            if output_qspec.quant_max is None
+            else output_qspec.quant_max
+        )
+        return DerivedQuantizationSpec(
+            derived_from=[(input, node) for input in node.args[0]],
+            derive_qparams_fn=partial(_derive_concat_qparams_fn, node),
+            dtype=output_qspec.dtype,
+            quant_min=q_min,
+            quant_max=q_max,
+            ch_axis=0,
+            qscheme=output_qspec.qscheme,
+        )
+
     if _is_annotated([node]) or not _is_float_tensor(node):
         return

-    assert isinstance(input_nodes, Sequence)
-
-    first_input_node = input_nodes[0]
     input_qspec_map = {}
-    assert isinstance(first_input_node, Node)
-    assert isinstance(node, Node)
-    if _is_float_tensor(first_input_node):
-        input_qspec_map[first_input_node] = quantization_config.input_activation
-        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
-            (first_input_node, node)
-        )
-
-    for input_node in input_nodes[1:]:
-        if input_node not in input_qspec_map:
-            assert isinstance(input_node, Node)
-            if _is_float_tensor(input_node):
-                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+    for input in node.args[0]:
+        input_qspec_map[input] = quantization_config.input_activation

     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=share_qparams_with_input_act0_qspec,
+        output_qspec=_derived_quant_spec(node, quantization_config.output_activation),
         _annotated=True,
+        allow_implicit_sharing=False,
     )


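The derived spec replaces the old share-with-input-0 scheme: qparams are now computed from the union of all inputs' observed ranges, and the same encoding is written back into every input's Q/DQ pair, so `cat` needs no per-input requantize. A toy calculation with made-up numbers:

```python
# Two concat inputs observed at [-1.0, 1.0] and [-4.0, 4.0], uint8 affine.
min_val = min(-1.0, -4.0)             # -4.0: the widest lower bound wins
max_val = max(1.0, 4.0)               #  4.0: the widest upper bound wins
scale = (max_val - min_val) / 255     # ~0.0314, shared by all inputs and the output
zero_point = round(-min_val / scale)  # ~128
# Every input Q/DQ pair is rewritten to this (scale, zero_point), so the
# backend can lower cat as pure data movement.
```

`allow_implicit_sharing=False` presumably keeps downstream annotators from silently aliasing this carefully derived output encoding.
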
@@ -1345,6 +1375,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     input_act = node.args[0]
     assert isinstance(input_act, Node)
     input_qspec_map[input_act] = quantization_config.input_activation
+    share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node))

     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -1353,7 +1384,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:

     for user in node.users:
         user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            output_qspec=quantization_config.output_activation,
+            output_qspec=share_qparams_with_input_node_qspec,
             _annotated=True,
         )

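With the shared spec, every `chunk` output reuses the input's encoding instead of observing its own range, so splitting a tensor never introduces a requantize. A toy module that exercises this annotator (illustrative only):

```python
import torch


class Halves(torch.nn.Module):
    """After quantization, both outputs carry the same scale/offset as x."""

    def forward(self, x):
        a, b = torch.chunk(x, chunks=2, dim=-1)
        return a, b
```
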
2 changes: 1 addition & 1 deletion backends/test/suite/operators/__init__.py
@@ -69,7 +69,7 @@ class TestCaseShim:
     def __init__(self, test_runner):
         self._test_runner = test_runner

-    def _test_op(self, model, args, flow, generate_random_test_inputs=True):
+    def _test_op(self, model, args, flow, generate_random_test_inputs=False):
         self._test_runner.lower_and_run_model(
             model, args, generate_random_test_inputs=generate_random_test_inputs
         )
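
Note the default flip: `_test_op` now runs the model only on the sample inputs it is handed, and randomized inputs become opt-in. Hypothetical call sites:

```python
# Uses `args` verbatim (the new default).
shim._test_op(model, args, flow)

# Explicitly opts back in to randomized test inputs.
shim._test_op(model, args, flow, generate_random_test_inputs=True)
```
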
5 changes: 2 additions & 3 deletions examples/qualcomm/oss_scripts/fastvit.py
@@ -75,7 +75,7 @@ def main(args):
         dtype=torch.uint8,
         qscheme=torch.per_tensor_affine,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
-            **{"averaging_constant": 0.02}
+            **{"averaging_constant": 0.01}
         ),
     )
     weight_qspec = QuantizationSpec(
@@ -85,7 +85,7 @@ def main(args):
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
-            **{"steps": 200, "use_mse": True}
+            **{"steps": 100, "use_mse": True}
         ),
     )
     # rewrite default per-channel ptq config
@@ -114,7 +114,6 @@ def main(args):
         dataset=inputs,
         skip_node_id_set=skip_node_id_set,
         skip_node_op_set=skip_node_op_set,
-        quant_dtype=QuantDtype.use_8a8w,
         custom_quantizer=quantizer,
         shared_buffer=args.shared_buffer,
     )
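
Both observer changes trade calibration speed for stability: halving `averaging_constant` makes the running range less reactive to outlier batches, and halving `steps` shortens `PerChannelParamObserver`'s MSE search. Dropping `quant_dtype` presumably defers entirely to the explicitly configured `custom_quantizer`. For reference, the moving-average update that `averaging_constant` controls amounts to the following (a sketch of the standard `MovingAverageMinMaxObserver` rule, not code from this PR):

```python
def update_range(running_min, running_max, batch_min, batch_max, c=0.01):
    # With c = 0.01 the running range moves only 1% toward each batch's
    # statistics (it was 2% before this change), damping outlier batches.
    running_min = running_min + c * (batch_min - running_min)
    running_max = running_max + c * (batch_max - running_max)
    return running_min, running_max
```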