diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index 0a947759538..4abbcc3145c 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -10,7 +10,11 @@ from executorch.exir.dialects._ops import ops as exir_ops
 
 not_supported_operator = [
+    # output size is data dependent on the slice index
+    exir_ops.edge.aten._embedding_bag.default,
+    # for graph sharding purposes; different from the op used in decoder models
     exir_ops.edge.dim_order_ops._clone_dim_order.default,
+    # QNN does not support 4-bit embedding
     exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
 
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index cf403a1a76d..b3111c57481 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -691,7 +691,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator([torch.ops.aten.slice.Tensor])
 def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.slice_scatter.default])
@@ -1277,32 +1277,62 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
 @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
 def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
-    input_nodes = node.args[0]
+    def _derived_quant_spec(
+        node: torch.fx.Node, output_qspec: QuantizationSpec
+    ) -> DerivedQuantizationSpec:
+        def _derive_concat_qparams_fn(
+            node, obs_or_fqs: List
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+            # derive a single encoding from the union of all observed input ranges
+            first_obj = obs_or_fqs[0]
+            first_obj.min_val = min(obj.min_val for obj in obs_or_fqs)
+            first_obj.max_val = max(obj.max_val for obj in obs_or_fqs)
+            scale, offset = first_obj.calculate_qparams()
+
+            # rewrite the input QDQ pairs with the maximum data range of all input tensors,
+            # since the topological order is constrained by the framework, which prevents us
+            # from creating dependencies like: [(node, input_0), ..., (node, input_n)]
+            for dq_node in node.args[0]:
+                # per-tensor args: (prev_node, scale, offset, q_min, q_max, dtype)
+                q_node = dq_node.args[0]
+                encoding = [scale.item(), offset.item(), *dq_node.args[3:]]
+                dq_node.args = (q_node, *encoding)
+                q_node.args = (q_node.args[0], *encoding)
+
+            return (scale, offset)
+
+        q_min = (
+            torch.iinfo(output_qspec.dtype).min
+            if output_qspec.quant_min is None
+            else output_qspec.quant_min
+        )
+        q_max = (
+            torch.iinfo(output_qspec.dtype).max
+            if output_qspec.quant_max is None
+            else output_qspec.quant_max
+        )
+        return DerivedQuantizationSpec(
+            derived_from=[(input, node) for input in node.args[0]],
+            derive_qparams_fn=partial(_derive_concat_qparams_fn, node),
+            dtype=output_qspec.dtype,
+            quant_min=q_min,
+            quant_max=q_max,
+            ch_axis=0,
+            qscheme=output_qspec.qscheme,
+        )
+
     if _is_annotated([node]) or not _is_float_tensor(node):
         return
 
-    assert isinstance(input_nodes, Sequence)
-
-    first_input_node = input_nodes[0]
     input_qspec_map = {}
-    assert isinstance(first_input_node, Node)
-    assert isinstance(node, Node)
-    if _is_float_tensor(first_input_node):
-        input_qspec_map[first_input_node] = quantization_config.input_activation
-        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
-            (first_input_node, node)
-        )
-
-    for input_node in input_nodes[1:]:
-        if input_node not in input_qspec_map:
-            assert isinstance(input_node, Node)
-            if _is_float_tensor(input_node):
-                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+    for input in node.args[0]:
+        input_qspec_map[input] = quantization_config.input_activation
 
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=share_qparams_with_input_act0_qspec,
+        output_qspec=_derived_quant_spec(node, quantization_config.output_activation),
         _annotated=True,
+        allow_implicit_sharing=False,
     )
 
 
@@ -1345,6 +1375,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     input_act = node.args[0]
     assert isinstance(input_act, Node)
     input_qspec_map[input_act] = quantization_config.input_activation
+    share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node))
 
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -1353,7 +1384,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
 
     for user in node.users:
         user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            output_qspec=quantization_config.output_activation,
+            output_qspec=share_qparams_with_input_node_qspec,
             _annotated=True,
         )
 
diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py
index 7475af29e15..825aa316771 100644
--- a/backends/test/suite/operators/__init__.py
+++ b/backends/test/suite/operators/__init__.py
@@ -69,7 +69,7 @@ class TestCaseShim:
     def __init__(self, test_runner):
         self._test_runner = test_runner
 
-    def _test_op(self, model, args, flow, generate_random_test_inputs=True):
+    def _test_op(self, model, args, flow, generate_random_test_inputs=False):
         self._test_runner.lower_and_run_model(
             model, args, generate_random_test_inputs=generate_random_test_inputs
         )
 
diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py
index f931da66448..89c7091dda4 100644
--- a/examples/qualcomm/oss_scripts/fastvit.py
+++ b/examples/qualcomm/oss_scripts/fastvit.py
@@ -75,7 +75,7 @@ def main(args):
         dtype=torch.uint8,
         qscheme=torch.per_tensor_affine,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
-            **{"averaging_constant": 0.02}
+            **{"averaging_constant": 0.01}
         ),
     )
     weight_qspec = QuantizationSpec(
@@ -85,7 +85,7 @@ def main(args):
         dtype=torch.int8,
         qscheme=torch.per_channel_symmetric,
         ch_axis=0,
         observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
-            **{"steps": 200, "use_mse": True}
+            **{"steps": 100, "use_mse": True}
         ),
     )
     # rewrite default per-channel ptq config
@@ -114,7 +114,6 @@ def main(args):
         dataset=inputs,
         skip_node_id_set=skip_node_id_set,
         skip_node_op_set=skip_node_op_set,
-        quant_dtype=QuantDtype.use_8a8w,
        custom_quantizer=quantizer,
         shared_buffer=args.shared_buffer,
     )
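
Reviewer note on the new `annotate_cat` flow: instead of sharing qparams with input 0, all concat inputs and the output now derive one encoding from the union of every observed input range. Below is a minimal standalone sketch of that range-union derivation, assuming stock per-tensor `MinMaxObserver`s and made-up input spreads (none of this code is from the PR itself):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# One observer per concat input, each seeing a different (assumed) data range.
observers = [MinMaxObserver() for _ in range(3)]
for obs, spread in zip(observers, (0.5, 1.0, 2.0)):
    obs(torch.randn(16) * spread)

# Union the observed ranges, mirroring _derive_concat_qparams_fn: all inputs
# and the output end up with one shared (scale, zero_point) encoding.
first = observers[0]
first.min_val = min(o.min_val for o in observers)
first.max_val = max(o.max_val for o in observers)
scale, zero_point = first.calculate_qparams()
print(f"shared scale={scale.item():.6f}, zero_point={zero_point.item()}")
```

In the PR, the output spec is derived from every input edge via `derived_from`, and the manual Q/DQ-arg rewrite then pushes the same encoding back onto each input's quantize/dequantize pair, which (per the in-code comment) cannot be expressed as derived specs on the inputs themselves due to the framework's topological-order constraint.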
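On the `fastvit.py` calibration tweaks: `averaging_constant` is the smoothing factor `c` in `MovingAverageMinMaxObserver`'s update `min_val ← (1 - c) · min_val + c · new_min` (likewise for `max_val`), so 0.01 tracks the activation range more smoothly than 0.02; `steps`, as I read `PerChannelParamObserver`, is the iteration budget of its MSE-based range search, so 100 halves calibration cost. A tiny sketch of the activation-observer behavior, with assumed calibration data (again illustrative, not from the PR):

```python
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

# Smaller averaging_constant -> slower-moving, smoother running min/max.
obs = MovingAverageMinMaxObserver(averaging_constant=0.01, dtype=torch.uint8)
for _ in range(100):  # assumed calibration batches
    obs(torch.randn(1, 3, 224, 224))
scale, zero_point = obs.calculate_qparams()
print(scale.item(), zero_point.item())
```

Dropping `quant_dtype=QuantDtype.use_8a8w` from the `build_executorch_binary` call looks consistent with the custom quantizer now owning the full quantization config, so the default dtype argument would be redundant there.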