From f9030b20072cc5c447af8a542744c202a95ecf5a Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Thu, 3 Oct 2024 22:47:12 -0700
Subject: [PATCH] fixed layer norm quantization annotation

---
 backends/qualcomm/quantizer/utils.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index ed9afd70ce6..c54b2bc3ebb 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -1055,17 +1055,25 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
     if _is_annotated([node]):
         return
+    input_act_qspec = quantization_config.input_activation
 
     _annotate_input_qspec_map(
         node,
         act_node,
-        quantization_config.input_activation,
-    )
-    _annotate_input_qspec_map(
-        node,
-        weight_node,
-        quantization_config.input_activation,
+        input_act_qspec,
     )
+    if input_act_qspec.dtype == torch.int32:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            get_default_16bit_qnn_ptq_config().weight,
+        )
+    else:
+        _annotate_input_qspec_map(
+            node,
+            weight_node,
+            input_act_qspec,
+        )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
         _annotate_input_qspec_map(
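
Note: the sketch below is not part of the patch; it is a minimal, self-contained illustration of the branch the patch adds. With a 16-bit QNN PTQ config the input-activation qspec carries dtype torch.int32, so reusing that activation qspec to annotate the layer-norm weight is wrong; the fix routes the weight to the 16-bit weight qspec from get_default_16bit_qnn_ptq_config() instead. The QSpec, FakeQuantizationConfig, and pick_layer_norm_weight_qspec names here are hypothetical stand-ins, not ExecuTorch APIs.

    # Minimal sketch of the dispatch introduced by the patch, using
    # hypothetical stand-in types rather than the real ExecuTorch qspecs.
    from dataclasses import dataclass

    import torch


    @dataclass
    class QSpec:
        dtype: torch.dtype
        label: str


    @dataclass
    class FakeQuantizationConfig:
        input_activation: QSpec
        weight: QSpec


    # Stand-in for get_default_16bit_qnn_ptq_config().weight in the patch:
    # the dedicated weight qspec that pairs with int32-typed activations.
    default_16bit_weight_qspec = QSpec(dtype=torch.int16, label="16-bit weight qspec")


    def pick_layer_norm_weight_qspec(config: FakeQuantizationConfig) -> QSpec:
        """Mirror the patched branch: reuse the activation qspec for the
        layer-norm weight only when activations are not int32-typed;
        otherwise fall back to the 16-bit weight qspec."""
        if config.input_activation.dtype == torch.int32:
            return default_16bit_weight_qspec
        return config.input_activation


    # 8-bit config: the weight keeps following the activation qspec
    # (the pre-patch behaviour, which was only wrong in the 16-bit case).
    cfg_8bit = FakeQuantizationConfig(
        input_activation=QSpec(dtype=torch.uint8, label="8-bit activation qspec"),
        weight=QSpec(dtype=torch.int8, label="8-bit weight qspec"),
    )
    assert pick_layer_norm_weight_qspec(cfg_8bit).label == "8-bit activation qspec"

    # 16-bit config: the weight must not inherit the int32 activation qspec.
    cfg_16bit = FakeQuantizationConfig(
        input_activation=QSpec(dtype=torch.int32, label="16-bit activation qspec"),
        weight=QSpec(dtype=torch.int16, label="16-bit weight qspec"),
    )
    assert pick_layer_norm_weight_qspec(cfg_16bit).label == "16-bit weight qspec"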