
Commit 709e739

Qualcomm AI Engine Direct - Optimization in static llama
Differential Revision: D65986986
Pull Request resolved: #6849
1 parent af87283

3 files changed: +87 -7 lines changed


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 81 additions & 0 deletions
@@ -14,13 +14,94 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+    """
+    This function is specific for matmul op 16a8w.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_matmul_input1(node: Node):
+        quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
+            act_symmetric=True, act_observer=MinMaxObserver
+        )
+        while isinstance(node, Node) and node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.permute.default,
+                torch.ops.aten.transpose.int,
+            ]:
+                annotate_single_in_single_out(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.cat.default:
+                annotate_cat(node, quantization_config_8a8w)
+                node = node.args[0][0]
+            else:
+                node = node.args[0]
+
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
+            annotate_matmul_input1(node.args[1])
+
+
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for llama matmul op 16a8w.
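
For readers unfamiliar with the traversal in annotate_matmul_input1, the following minimal sketch (not part of this commit) exports a toy attention-style module and walks back from the second matmul input through its transpose/cat producers, which is the same pattern the new annotation follows (in the llama graphs this second input is typically the cached key/value path). The ToyAttention module, its shapes, and the condensed walk below are illustrative assumptions only.

import torch


class ToyAttention(torch.nn.Module):
    def forward(self, q, k_cache, k_new):
        # concatenate cached and new keys, then transpose for the score matmul,
        # mirroring the cat/transpose chain that annotate_matmul_input1 walks over
        k = torch.cat([k_cache, k_new], dim=0)
        return torch.matmul(q, k.transpose(-2, -1))


gm = torch.export.export(
    ToyAttention(), (torch.randn(1, 8), torch.randn(3, 8), torch.randn(1, 8))
).module()

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
        # node.args[1] is the second matmul input; the annotation starts here and
        # follows producers (permute/transpose/cat) until a non-matching op is hit
        src = node.args[1]
        while isinstance(src, torch.fx.Node) and src.op == "call_function":
            print(src.target)  # typically aten.transpose.int, then aten.cat.default
            if src.target == torch.ops.aten.cat.default:
                src = src.args[0][0]
            else:
                src = src.args[0]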

examples/qualcomm/oss_scripts/llama2/model/static_llama.py

Lines changed: 3 additions & 4 deletions
@@ -13,7 +13,6 @@
     FeedForward,
     ModelArgs,
     precompute_freqs_cis,
-    RMSNorm,
 )


@@ -191,8 +190,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             config=config, output_new_cache_only=output_new_cache_only
         )
         self.feed_forward = FeedForward(config)
-        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
-        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

     def forward(
         self,
@@ -236,7 +235,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
                 for _ in range(config.n_layers)
             ]
         )
-        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         freqs_cos, freqs_sin = precompute_freqs_cis(
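
This file drops the example's local RMSNorm in favor of torch.nn.RMSNorm, which ships with recent PyTorch (2.4 and later). As a quick sanity check that the swap is behavior-preserving, the sketch below compares the built-in module against a hand-rolled llama-style reference; RefRMSNorm is an illustrative stand-in, not the exact class that was removed.

import torch


class RefRMSNorm(torch.nn.Module):
    """Hand-rolled llama-style RMSNorm, used only as a reference for comparison."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # scale by the reciprocal root-mean-square over the last dimension
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


dim, eps = 64, 1e-5
x = torch.randn(2, 8, dim)
print(torch.allclose(RefRMSNorm(dim, eps)(x), torch.nn.RMSNorm(dim, eps=eps)(x), atol=1e-6))
# expected: True, since torch.nn.RMSNorm computes x / sqrt(mean(x^2) + eps) * weight

Because torch.nn.RMSNorm also stores its scale under the `weight` parameter name, existing checkpoints should keep loading, assuming the removed implementation used the same name.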

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 3 additions & 3 deletions
@@ -19,8 +19,8 @@
 from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

 from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_matmul_16a8w,
     custom_annotate_llama_last_conv_16a8w,
-    custom_annotate_llama_matmul_16a8w,
 )

 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
@@ -76,7 +76,7 @@ def calibrate(
     token_list = sp_model.encode(user_prompts, bos=True, eos=False)

     with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < 512:
+        while token_list[-1] != sp_model.eos_id and pos < 511:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 torch.full((1, 1), pos),
@@ -295,7 +295,7 @@ def compile(args):
         quant_dtype,
         custom_annotations=(
             custom_annotate_llama_last_conv_16a8w,
-            custom_annotate_llama_matmul_16a8w,
+            annotate_matmul_16a8w,
         ),
     )
     end_quantize_ts = time.time()
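
For context, this is roughly how a custom_annotations tuple like the one above gets attached to the Qualcomm quantizer in the PT2E flow. The sketch is not taken from this script: the toy module, the use of torch.export.export_for_training, and the add_custom_quant_annotations method name are assumptions based on the QnnQuantizer in the same backend, so verify against the quantizer source before relying on them.

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class ToyProj(torch.nn.Module):
    def forward(self, x, w):
        # a bare matmul so annotate_matmul_16a8w has something to annotate
        return torch.matmul(x, w)


example_inputs = (torch.randn(1, 4, 8), torch.randn(8, 8))

# export to an ATen-level GraphModule for PT2E quantization
gm = torch.export.export_for_training(ToyProj(), example_inputs).module()

quantizer = QnnQuantizer()
# assumed API: register graph-level callbacks that run on top of the default annotation
quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))

prepared = prepare_pt2e(gm, quantizer)
prepared(*example_inputs)  # calibration; the llama script loops over real prompts instead
quantized = convert_pt2e(prepared)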
