51 changes: 36 additions & 15 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -12,8 +12,11 @@
)
from executorch.backends.qualcomm.quantizer.quantizer import (
get_16a8w_qnn_ptq_config,
get_16a8w_qnn_qat_config,
get_8a8w_qnn_ptq_config,
get_8a8w_qnn_qat_config,
get_ptq_per_channel_quant_config,
get_qat_per_channel_quant_config,
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
@@ -154,7 +157,9 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):


def annotate_matmul_16a8w( # noqa: C901
gm: torch.fx.GraphModule, annotate_conv=True
gm: torch.fx.GraphModule,
annotate_conv=True,
is_qat=False,
) -> None:
"""
This function is specific for matmul op 16a8w.
Expand Down Expand Up @@ -238,7 +243,6 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:

input_qspec_map = {}
input_act = node.args[0]
input_qspec_map[input_act] = quantization_config.input_activation
@@ -252,7 +256,6 @@ def annotate_single_in_share_out(
def annotate_single_in_share_out(
node: Node, quantization_config: QuantizationConfig
) -> None:

input_qspec_map = {}
input_act = node.args[0]
input_qspec_map[input_act] = quantization_config.input_activation
@@ -283,16 +286,27 @@ def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
_annotated=True,
)

def annotate_matmul_input1(node: Node):
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int4,
act_observer=MinMaxObserver,
act_symmetric=True,
)
def annotate_matmul_input1(node: Node, is_qat: bool):
if is_qat:
quantization_config_8a8w = get_8a8w_qnn_qat_config(
act_symmetric=True, act_observer=MinMaxObserver
)
quantization_config_8a4w_per_channel = get_qat_per_channel_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int4,
act_observer=MinMaxObserver,
act_symmetric=True,
)
else:
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int4,
act_observer=MinMaxObserver,
act_symmetric=True,
)
while isinstance(node, Node) and node.op == "call_function":
if node.target in [
torch.ops.aten.permute.default,
@@ -330,12 +344,19 @@ def annotate_matmul_input1(node: Node):
print(f"The node ({node}) is not expected in the input1 of the matmul")
node = node.args[0]

quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
if is_qat:
quantization_config_16a8w = get_16a8w_qnn_qat_config(
act_observer=MinMaxObserver
)
else:
quantization_config_16a8w = get_16a8w_qnn_ptq_config(
act_observer=MinMaxObserver
)

for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
annotate_matmul_input1(node.args[1])
annotate_matmul_input1(node.args[1], is_qat=is_qat)


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
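The new is_qat switch makes this annotator usable in a quantization-aware-training flow. A minimal sketch of driving it (the toy module and export steps below are illustrative assumptions, not part of this change):

```python
# Sketch: export a tiny matmul module and annotate it with the QAT variant.
# Only annotate_matmul_16a8w comes from this change; TinyAttention and the
# export flow are illustrative assumptions.
import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)


class TinyAttention(torch.nn.Module):
    def forward(self, q, k):
        return torch.matmul(q, k.transpose(-1, -2))


example_inputs = (torch.rand(1, 8, 64), torch.rand(1, 8, 64))
gm = torch.export.export(TinyAttention(), example_inputs).module()

# is_qat=True routes matmul and its input1 chain to the FakeQuantize-based
# *_qat_* configs instead of the observer-based PTQ configs.
annotate_matmul_16a8w(gm, is_qat=True)
```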
123 changes: 107 additions & 16 deletions backends/qualcomm/quantizer/qconfig.py
@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
return quantization_config


def get_16a8w_qnn_qat_config(
act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
extra_args: Dict[str, Any] = {"eps": 2**-20}
act_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
quant_max=torch.iinfo(torch.uint16).max,
qscheme=torch.per_tensor_affine,
reduce_range=True,
observer=act_observer.with_args(**extra_args),
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
quant_max=torch.iinfo(torch.uint16).max,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)
weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
dtype=torch.int8,
quant_min=torch.iinfo(torch.int8).min + 1,
quant_max=torch.iinfo(torch.int8).max,
qscheme=torch.per_tensor_symmetric,
reduce_range=True,
observer=MovingAverageMinMaxObserver,
)
weight_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=torch.iinfo(torch.int8).min + 1,
quant_max=torch.iinfo(torch.int8).max,
qscheme=torch.per_tensor_symmetric,
ch_axis=0,
observer_or_fake_quant_ctr=weight_fake_quant_ctr,
)
bias_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max,
qscheme=torch.per_tensor_symmetric,
observer=MovingAverageMinMaxObserver,
)
bias_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max,
qscheme=torch.per_tensor_symmetric,
observer_or_fake_quant_ctr=bias_fake_quant_ctr,
)
quantization_config = QuantizationConfig(
input_activation=act_quantization_spec,
output_activation=act_quantization_spec,
weight=weight_quantization_spec,
bias=bias_quantization_spec,
)

return quantization_config


def get_16a16w_qnn_ptq_config(
act_observer=MovingAverageMinMaxObserver,
) -> QuantizationConfig:
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
act_dtype=torch.uint8,
weight_dtype=torch.int8,
act_observer=MovingAverageMinMaxObserver,
act_symmetric=False,
) -> QuantizationConfig:
supported_act_types = {
torch.uint8,
@@ -476,21 +536,38 @@
), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

# torch does not support uint16 quantization, use int32 to bypass
act_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
reduce_range=True,
observer=act_observer,
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)
if act_symmetric:
# If zero_point is 128, HTP can apply optimizations.
# If quant_min and quant_max are left as None, the observer defaults to a zero_point of 128.
# Providing the uint8 quant_min/quant_max would make it use 127 as the zero_point, which is undesired.
act_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
qscheme=torch.per_tensor_symmetric,
reduce_range=True,
observer=act_observer,
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
qscheme=torch.per_tensor_symmetric,
ch_axis=0,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)
else:
act_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
reduce_range=True,
observer=act_observer,
)
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
quant_min=torch.iinfo(act_dtype).min,
quant_max=torch.iinfo(act_dtype).max,
qscheme=torch.per_tensor_affine,
observer_or_fake_quant_ctr=act_fake_quant_ctr,
)

weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
observer_or_fake_quant_ctr=weight_fake_quant_ctr,
)

bias_quantization_spec = _derived_bias_quant_spec
bias_fake_quant_ctr = FakeQuantize.with_args(
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max,
qscheme=torch.per_tensor_symmetric,
reduce_range=True,
observer=MovingAverageMinMaxObserver,
)
bias_quantization_spec = QuantizationSpec(
dtype=torch.int32,
quant_min=torch.iinfo(torch.int32).min,
quant_max=torch.iinfo(torch.int32).max,
qscheme=torch.per_tensor_symmetric,
observer_or_fake_quant_ctr=bias_fake_quant_ctr,
)

quantization_config = QuantizationConfig(
input_activation=act_quantization_spec,
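The QAT factories mirror their PTQ counterparts argument-for-argument, so callers can branch on a single flag. A small illustrative helper (an assumption, not part of the diff) showing the selection that annotate_matmul_16a8w now performs internally:

```python
# Sketch: pick the 16a8w activation config and the 8a4w per-channel weight
# config for either PTQ or QAT. The helper itself is an assumption; the
# imported factories are the ones touched by this change.
import torch
from torch.ao.quantization.observer import MinMaxObserver
from executorch.backends.qualcomm.quantizer.quantizer import (
    get_16a8w_qnn_ptq_config,
    get_16a8w_qnn_qat_config,
    get_ptq_per_channel_quant_config,
    get_qat_per_channel_quant_config,
)


def select_16a8w_configs(is_qat: bool):
    common = dict(
        act_dtype=torch.uint8,
        weight_dtype=torch.int4,
        act_observer=MinMaxObserver,
        act_symmetric=True,
    )
    if is_qat:
        return (
            get_16a8w_qnn_qat_config(act_observer=MinMaxObserver),
            get_qat_per_channel_quant_config(**common),
        )
    return (
        get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver),
        get_ptq_per_channel_quant_config(**common),
    )
```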
2 changes: 2 additions & 0 deletions backends/qualcomm/quantizer/quantizer.py
@@ -23,6 +23,7 @@
get_16a4w_qnn_ptq_config,
get_16a4w_qnn_qat_config,
get_16a8w_qnn_ptq_config,
get_16a8w_qnn_qat_config,
get_8a8w_qnn_ptq_config,
get_8a8w_qnn_qat_config,
get_ptq_per_block_quant_config,
@@ -39,6 +40,7 @@
"QuantDtype",
"get_16a4w_qnn_ptq_config",
"get_16a8w_qnn_ptq_config",
"get_16a8w_qnn_qat_config",
"get_16a16w_qnn_ptq_config",
"get_8a8w_qnn_ptq_config",
"get_8a8w_qnn_qat_config",
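With the re-export in place, the QAT config is reachable from the public quantizer module alongside the existing PTQ entry points. A quick illustrative check of the spec it produces (an assumption for readers, not test code from the PR):

```python
# Sketch: the new factory returns a QuantizationConfig whose activation spec
# carries the uint16 range in an int32 container, matching qconfig.py above.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import get_16a8w_qnn_qat_config

cfg = get_16a8w_qnn_qat_config()
act = cfg.input_activation
assert act.dtype == torch.int32
assert act.quant_max == torch.iinfo(torch.uint16).max
print(cfg.weight.dtype, cfg.bias.dtype)  # expected: torch.int8 torch.int32
```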