diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 331622ee71b..99016871a8a 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -12,8 +12,11 @@
 )
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
+    get_8a8w_qnn_qat_config,
     get_ptq_per_channel_quant_config,
+    get_qat_per_channel_quant_config,
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -154,7 +157,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
 
 
 def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, annotate_conv=True
+    gm: torch.fx.GraphModule,
+    annotate_conv=True,
+    is_qat=False,
 ) -> None:
     """
     This function is specific for matmul op 16a8w.
@@ -238,7 +243,6 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
     def annotate_single_in_single_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -252,7 +256,6 @@ def annotate_single_in_single_out(
     def annotate_single_in_share_out(
         node: Node, quantization_config: QuantizationConfig
     ) -> None:
-
         input_qspec_map = {}
         input_act = node.args[0]
         input_qspec_map[input_act] = quantization_config.input_activation
@@ -283,16 +286,27 @@ def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
             _annotated=True,
         )
 
-    def annotate_matmul_input1(node: Node):
-        quantization_config_8a8w = get_8a8w_qnn_ptq_config(
-            act_symmetric=True, act_observer=MinMaxObserver
-        )
-        quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
-            act_dtype=torch.uint8,
-            weight_dtype=torch.int4,
-            act_observer=MinMaxObserver,
-            act_symmetric=True,
-        )
+    def annotate_matmul_input1(node: Node, is_qat: bool):
+        if is_qat:
+            quantization_config_8a8w = get_8a8w_qnn_qat_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_qat_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
+        else:
+            quantization_config_8a8w = get_8a8w_qnn_ptq_config(
+                act_symmetric=True, act_observer=MinMaxObserver
+            )
+            quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
+                act_dtype=torch.uint8,
+                weight_dtype=torch.int4,
+                act_observer=MinMaxObserver,
+                act_symmetric=True,
+            )
         while isinstance(node, Node) and node.op == "call_function":
             if node.target in [
                 torch.ops.aten.permute.default,
@@ -330,12 +344,19 @@ def annotate_matmul_input1(node: Node):
                 print(f"The node ({node}) is not expected in the input1 of the matmul")
                 node = node.args[0]
 
-    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+    if is_qat:
+        quantization_config_16a8w = get_16a8w_qnn_qat_config(
+            act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w = get_16a8w_qnn_ptq_config(
+            act_observer=MinMaxObserver
+        )
 
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1], is_qat=is_qat)
 
 
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index 748128ceafd..b510a8d9c7e 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
     return quantization_config
 
 
+def get_16a8w_qnn_qat_config(
+    act_observer=MovingAverageMinMaxObserver,
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-20}
+    act_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=True,
+        observer=act_observer.with_args(**extra_args),
+    )
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=act_fake_quant_ctr,
+    )
+    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_16a16w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric=False,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -476,21 +536,38 @@ def get_qat_per_channel_quant_config(
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
 
     # torch does not support uint16 quantization, use int32 to bypass
-    act_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        reduce_range=True,
-        observer=act_observer,
-    )
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_fake_quant_ctr,
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply further optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to 128 as zero_point.
+        # If we pass uint8 quant_min/quant_max explicitly, it uses 127 as zero_point, which is undesired.
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
 
     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )
 
-    bias_quantization_spec = _derived_bias_quant_spec
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
 
     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index e14d73f521d..5943b54d968 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -23,6 +23,7 @@
     get_16a4w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
     get_16a8w_qnn_ptq_config,
+    get_16a8w_qnn_qat_config,
     get_8a8w_qnn_ptq_config,
     get_8a8w_qnn_qat_config,
     get_ptq_per_block_quant_config,
@@ -39,6 +40,7 @@
     "QuantDtype",
     "get_16a4w_qnn_ptq_config",
     "get_16a8w_qnn_ptq_config",
+    "get_16a8w_qnn_qat_config",
     "get_16a16w_qnn_ptq_config",
     "get_8a8w_qnn_ptq_config",
     "get_8a8w_qnn_qat_config",
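
Usage sketch (illustrative, not part of the diff): a minimal example of the new QAT entry points, assuming the import paths shown in the patch. The variable names (`qat_16a8w`, `qat_8a4w_per_channel`, `annotate_matmul_16a8w_qat`) are placeholders, and the model/quantizer wiring (QnnQuantizer registration, prepare_qat_pt2e) is intentionally omitted.

    # Exercise only the helpers added by this patch; everything else is assumed context.
    from functools import partial

    import torch
    from executorch.backends.qualcomm.quantizer.custom_annotation import (
        annotate_matmul_16a8w,
    )
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_16a8w_qnn_qat_config,
        get_qat_per_channel_quant_config,
    )

    # New QAT counterparts of the existing PTQ configs.
    qat_16a8w = get_16a8w_qnn_qat_config()
    qat_8a4w_per_channel = get_qat_per_channel_quant_config(
        act_dtype=torch.uint8,
        weight_dtype=torch.int4,
        act_symmetric=True,  # new flag introduced by this patch
    )
    print(qat_16a8w.weight, qat_8a4w_per_channel.input_activation)

    # The matmul annotator can now emit QAT observers/fake-quant modules via
    # is_qat; in a real flow this callable would be registered with the
    # quantizer as a custom annotation before prepare_qat_pt2e (not shown here).
    annotate_matmul_16a8w_qat = partial(annotate_matmul_16a8w, is_qat=True)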