diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index 0efacb143dd..230087eeab4 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -16,6 +16,7 @@
     QUANT_ANNOTATION_KEY,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
@@ -23,6 +24,86 @@
 from torch.fx import Node
 
 
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+    """
+    This function is specific for matmul op 16a8w.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_matmul_input1(node: Node):
+        quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
+            act_symmetric=True, act_observer=MinMaxObserver
+        )
+        while isinstance(node, Node) and node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.permute.default,
+                torch.ops.aten.transpose.int,
+            ]:
+                annotate_single_in_single_out(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.cat.default:
+                annotate_cat(node, quantization_config_8a8w)
+                node = node.args[0][0]
+            else:
+                node = node.args[0]
+
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
+            annotate_matmul_input1(node.args[1])
+
+
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for llama matmul op 16a8w.
diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
index 94d44379b67..49fc5721281 100755
--- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
@@ -13,7 +13,6 @@
     FeedForward,
     ModelArgs,
     precompute_freqs_cis,
-    RMSNorm,
 )
 
 
@@ -191,8 +190,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             config=config, output_new_cache_only=output_new_cache_only
         )
         self.feed_forward = FeedForward(config)
-        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
 
     def forward(
         self,
@@ -236,7 +235,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
                 for _ in range(config.n_layers)
             ]
         )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         freqs_cos, freqs_sin = precompute_freqs_cis(
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index 20d674888da..c3872ef9e93 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -19,8 +19,8 @@
 
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
 from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_matmul_16a8w,
     custom_annotate_llama_last_conv_16a8w,
-    custom_annotate_llama_matmul_16a8w,
 )
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 
@@ -78,7 +78,7 @@ def calibrate(
     token_list = sp_model.encode(user_prompts, bos=True, eos=False)
 
     with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < 512:
+        while token_list[-1] != sp_model.eos_id and pos < 511:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 torch.full((1, 1), pos),
@@ -297,7 +297,7 @@ def compile(args):
            quant_dtype,
            custom_annotations=(
                custom_annotate_llama_last_conv_16a8w,
-                custom_annotate_llama_matmul_16a8w,
+                annotate_matmul_16a8w,
            ),
        )
        end_quantize_ts = time.time()
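
Note for reviewers: a minimal usage sketch of the new annotator, mirroring the compile(args) hunk above. The quantize call at the end is an assumption about the surrounding example code and is not part of this patch.

# Hypothetical sketch only, not part of the diff.
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,  # annotates aten.matmul as 16a8w and its arg1 chain as 8a8w
    custom_annotate_llama_last_conv_16a8w,
)

custom_annotations = (
    custom_annotate_llama_last_conv_16a8w,
    annotate_matmul_16a8w,
)

# Assumed call site (name/signature taken from the example, not verified here):
# single_llama.quantize(quant_dtype, custom_annotations=custom_annotations)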