Commit c869d51

Qualcomm AI Engine Direct - GA Static Olmo-1b
Summary:
- e2e script for GA Static OLMo-1b
- perf: 16a4w block quant token rate in kv mode ~= 63 tokens/sec (SM8750)
- accuracy: wikitext PPL ~= 8.735 (fp) -> 9.945 (htp)
- add model params file & model weight converter
- add workaround pass for LayerNorm without weight & unit test
- fix layer_norm op builder & fix layer_norm quant annotator
1 parent 2845fd3 commit c869d51

22 files changed: +514 additions, -44 deletions

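As a quick orientation before the per-file diffs, a hedged sketch of driving the new e2e flow the same way the added test_static_olmo test does (the artifact dir, build folder, SoC model, and device serial below are placeholders, and the --ip/--port listener plumbing used by the CI test is omitted):

# Hedged sketch: invoke the GA Static OLMo-1b script with the flags exercised
# by test_static_olmo further down in this commit. Paths and serials are
# placeholders.
import subprocess

cmds = [
    "python",
    "examples/qualcomm/oss_scripts/llama/llama.py",
    "--decoder_model", "olmo-1b",
    "--model_mode", "kv",
    "--max_seq_len", "1024",
    "--prompt", "Simply put, the theory of relativity states that",
    "--temperature", "0",
    "--eval_perplexity",
    "--task", "wikitext",
    "--artifact", "./olmo_artifacts",   # placeholder output dir
    "--build_folder", "build-android",  # placeholder build folder
    "--model", "SM8750",                # SoC model, as in the test's speed refs
    "--device", "<adb-serial>",         # placeholder device serial
]
subprocess.run(cmds, check=True)
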
backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from .fuse_consecutive_cast import FuseConsecutiveCast
 from .fuse_consecutive_transpose import FuseConsecutiveTranspose
 from .i64_to_i32 import I64toI32
+from .insert_frozen_layer_norm_weight import InsertFrozenLayerNormWeight
 from .insert_io_qdq import InsertIOQDQ
 from .insert_requantize import InsertRequantize
 from .layout_transform import LayoutTransform
@@ -67,6 +68,7 @@
     FuseConsecutiveCast,
     FuseConsecutiveTranspose,
     I64toI32,
+    InsertFrozenLayerNormWeight,
     InsertIOQDQ,
     InsertRequantize,
     LayoutTransform,

backends/qualcomm/_passes/insert_frozen_layer_norm_weight.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.exir.pass_base import ExportPass, PassResult


# TODO: Remove this workaround once HTP fixes the bug; LayerNorm without weights
# should be supported.
class InsertFrozenLayerNormWeight(ExportPass):
    """
    This pass injects a frozen weight parameter (filled with ones) into LayerNorm ops
    that were exported without a weight (i.e., elementwise_affine=False), to satisfy
    backends that require the presence of a weight parameter.

    It operates at the ExportedProgram level, modifying both the FX graph and
    the graph_signature to include the new frozen parameter.

    Example transformation:

    Before:
        %out = aten.layer_norm(%x, normalized_shape=[128], weight=None, bias=None, eps=1e-5)

    After:
        %weight = get_attr("layer_norm_weight_0")
        %out = aten.layer_norm(%x, normalized_shape=[128], weight=%weight, bias=None, eps=1e-5)

    The injected weight is a frozen parameter with all values set to 1.0.
    """

    def __init__(self):
        super(InsertFrozenLayerNormWeight, self).__init__()
        self.layer_norm = torch.ops.aten.layer_norm.default

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        modified = False
        frozen_weight_idx = 0

        for node in graph.nodes:
            if node.op != "call_function" or node.target != self.layer_norm:
                continue

            # Detect LayerNorm ops missing the 'weight' argument
            if len(node.args) < 3:
                normalized_shape = node.args[1]

                # Create a frozen weight tensor filled with ones
                param_name = f"{self.layer_norm.__name__.split('.')[0]}_weight_{frozen_weight_idx}"
                frozen_weight = torch.ones(normalized_shape)
                graph_module.register_buffer(param_name, frozen_weight)
                with graph.inserting_before(node):
                    weight_node = graph.get_attr(param_name)
                node.args = (node.args[0], node.args[1], weight_node, *node.args[3:])

                frozen_weight_idx += 1
                modified = True

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, modified)
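
A minimal usage sketch of this pass on a weight-less LayerNorm, assuming torch.export keeps aten.layer_norm.default un-decomposed at this stage (the module and shapes below are illustrative, not part of this commit):

# Hedged sketch: run InsertFrozenLayerNormWeight on an exported graph that
# calls layer_norm with weight=None, bias=None. If the exporter records the
# Nones as explicit args (len(node.args) >= 3), the pass leaves the node alone.
import torch
from executorch.backends.qualcomm._passes import InsertFrozenLayerNormWeight


class WeightlessLayerNorm(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, (128,), None, None, eps=1e-5)


ep = torch.export.export(WeightlessLayerNorm(), (torch.randn(4, 128),))
result = InsertFrozenLayerNormWeight().call(ep.graph_module)

for node in result.graph_module.graph.nodes:
    if node.target == torch.ops.aten.layer_norm.default:
        # After the pass, args[2] should be a get_attr node pointing at a
        # ones-filled buffer named like "layer_norm_weight_0".
        print(node.args)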

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,7 @@
     FuseConsecutiveCast,
     FuseConsecutiveTranspose,
     I64toI32,
+    InsertFrozenLayerNormWeight,
     InsertIOQDQ,
     InsertRequantize,
     LayoutTransform,
@@ -201,6 +202,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeEinsum())
         self.add_pass(DecomposeExpM1())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
+        self.add_pass(InsertFrozenLayerNormWeight())
         self.add_pass(ReplaceInfValues())
         self.add_pass(LiftConstantScalarOperands())
         return self._transform(graph_module)
@@ -220,6 +222,7 @@ def transform_for_export_pipeline(
         if convert_linear_to_conv2d:
             self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
+        self.add_pass(InsertFrozenLayerNormWeight())
         self.add_pass(LiftConstantScalarOperands())
         self._transform(exported_program.graph_module)
         ep = lift_constant_tensor_pass(exported_program)

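Both pipelines above are driven by QnnPassManager, so a rough usage sketch of where the new pass takes effect (the toy module and the call pattern are assumptions, not from this commit):

# Hedged sketch: the annotation pipeline now schedules InsertFrozenLayerNormWeight,
# so a weight-less LayerNorm picks up a frozen ones weight before the quantizer
# annotates the graph. The model below is a placeholder.
import torch
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

model = torch.nn.Sequential(torch.nn.LayerNorm(8, elementwise_affine=False))
ep = torch.export.export(model, (torch.randn(2, 8),))

gm = QnnPassManager().transform_for_annotation_pipeline(ep.graph_module)
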
backends/qualcomm/builders/op_layer_norm.py

Lines changed: 11 additions & 10 deletions
@@ -40,6 +40,7 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        layer_norm_input_tensors = [input_tensor_wrapper]

         normalized_shapes = node.args[1]
         if (
@@ -55,16 +56,16 @@ def define_node(
         axis_shape = [len(axis)]

         weight_node = self.get_node(node.args[2])
-        weight_tensor = get_parameter(weight_node, self.edge_program)
-        weight_tensor_wrapper = self.define_tensor(
-            weight_node,
-            node,
-            weight_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-            nodes_to_wrappers,
-        )
-
-        layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
+        if weight_node is not None:
+            weight_tensor = get_parameter(weight_node, self.edge_program)
+            weight_tensor_wrapper = self.define_tensor(
+                weight_node,
+                node,
+                weight_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
+            layer_norm_input_tensors.append(weight_tensor_wrapper)

         bias_node = self.get_node(node.args[3])
         if bias_node is not None:

backends/qualcomm/quantizer/annotators.py

Lines changed: 20 additions & 15 deletions
@@ -1235,9 +1235,12 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
 @register_annotator([torch.ops.aten.layer_norm.default])
 def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     act_node = node.args[0]
-    weight_node = node.args[2]
-    bias_node = None
+    # OLMo uses LayerNorm with no learnable weight and bias.
+    weight_node = None
     if len(node.args) > 2:
+        weight_node = node.args[2]
+    bias_node = None
+    if len(node.args) > 3:
         bias_node = node.args[3]

     if _is_annotated([node]):
@@ -1249,19 +1252,21 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
         act_node,
         input_act_qspec,
     )
-    if input_act_qspec.dtype == torch.int32:
-        annotate_input_qspec_map(
-            node,
-            weight_node,
-            get_16a16w_qnn_ptq_config().weight,
-        )
-    else:
-        annotate_input_qspec_map(
-            node,
-            weight_node,
-            input_act_qspec,
-        )
-    nodes_to_mark_annotated = [node, weight_node]
+    nodes_to_mark_annotated = [node]
+    if weight_node:
+        if input_act_qspec.dtype == torch.int32:
+            annotate_input_qspec_map(
+                node,
+                weight_node,
+                get_16a16w_qnn_ptq_config().weight,
+            )
+        else:
+            annotate_input_qspec_map(
+                node,
+                weight_node,
+                input_act_qspec,
+            )
+        nodes_to_mark_annotated.append(weight_node)
     if bias_node:
         annotate_input_qspec_map(
             node,

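The new length guards exist because, depending on the exporter, the trailing weight/bias arguments of aten.layer_norm.default may be dropped rather than recorded as None. A small inspection sketch (module and shapes are illustrative):

# Hedged sketch: print how layer_norm nodes actually carry their args when the
# module passes weight=None, bias=None. The arity observed here is what the
# len(node.args) > 2 / > 3 checks above defend against.
import torch


class WeightlessLayerNorm(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, (64,), None, None, eps=1e-5)


ep = torch.export.export(WeightlessLayerNorm(), (torch.randn(2, 64),))
for node in ep.graph_module.graph.nodes:
    if node.target == torch.ops.aten.layer_norm.default:
        print(len(node.args), node.args)
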
backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 40 additions & 4 deletions
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from enum import Enum, unique
 from typing import Sequence

 import torch
@@ -31,6 +32,17 @@
 )


+@unique
+class StaticLLMQuantConfig(Enum):
+    """
+    Layer namespace configuration for Qualcomm's static LLaMA quantization.
+    """
+
+    wq_sha = "wq_sha"  # Query weight (single head)
+    wk_sha = "wk_sha"  # Key weight (single head)
+    wv_sha = "wv_sha"  # Value weight (single head)
+
+
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
@@ -166,11 +178,35 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
         )


-def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig):
+def annotate_qkv_proj_sha(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    qkv_tags: set[StaticLLMQuantConfig],
+):
+    """
+    Annotates QKV projection layers in a GraphModule for quantization,
+    specifically layers defined in StaticLLMQuantConfig.
+
+    Args:
+        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV
+            layers (e.g., wq, wk, wv) should be annotated for quantization. Only tags
+            defined in StaticLLMQuantConfig are allowed.
+
+    Raises:
+        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
+    """
+
+    # Get all valid tags from the StaticLLMQuantConfig enum
+    allowed_tags = set(StaticLLMQuantConfig)
+    invalid_tags = qkv_tags - allowed_tags
+    if invalid_tags:
+        raise ValueError(
+            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
+        )
+
     for node in gm.graph.nodes:
-        if (
-            node.target == torch.ops.aten.conv2d.default
-            and "wv_sha" in node.meta["stack_trace"]
+        if node.target == torch.ops.aten.conv2d.default and any(
+            tag.value in node.meta["stack_trace"] for tag in qkv_tags
         ):
             input_qspec_map = {}
             input_qspec_map[node.args[0]] = quantization_config.input_activation

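A brief usage sketch of the renamed helper; the quantization_config here is a placeholder supplied by the caller's quantizer setup, only the tag plumbing is taken from this commit:

# Hedged sketch: pick which single-head QKV projections receive the custom
# annotation. Passing anything outside StaticLLMQuantConfig raises ValueError.
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_qkv_proj_sha,
)


def annotate_qkv(gm, quantization_config):
    qkv_tags = {
        StaticLLMQuantConfig.wq_sha,
        StaticLLMQuantConfig.wk_sha,
        StaticLLMQuantConfig.wv_sha,
    }
    annotate_qkv_proj_sha(gm, quantization_config, qkv_tags)
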
backends/qualcomm/tests/models.py

Lines changed: 11 additions & 0 deletions
@@ -1113,6 +1113,17 @@ def forward(self, x):
         return self.linear2(x1)


+class LayerNormWithoutParams(torch.nn.Module):
+    def __init__(self, hidden_size: int):
+        super().__init__()
+        self.normalized_shape = (hidden_size,)
+
+    def forward(self, x):
+        return torch.nn.functional.layer_norm(
+            x, self.normalized_shape, None, None, eps=1e-5
+        )
+
+
 class LayerNorm(torch.nn.Module):
     def __init__(self, bias=True):
         super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 69 additions & 2 deletions
@@ -832,7 +832,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNormWithoutParams(768),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -2360,7 +2364,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNormWithoutParams(768),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -4863,6 +4871,65 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai

+    def test_static_olmo(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "Simply put, the theory of relativity states that"
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "olmo-1b",
+            "--model_mode",
+            "kv",
+            "--temperature",
+            "0",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--task",
+            "wikitext",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 35, "SM8750": 60}
+                self.assertLessEqual(msg["wiki_ppl"], 10)
+                self.assertLessEqual(msg["pte_size"], 1_000_000_000)  # 1GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_phi4(self):
         if not self.required_envs():
             self.skipTest("missing required envs")

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     attention_type: str = "mha"  # Attention type, registered in attention.py
+    norm_type: str = "rmsnorm"  # Normalization type, registered in norm.py
     attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (

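A hedged sketch of how a LayerNorm-based decoder such as OLMo could opt into the new field; the dimension values and the "layernorm" registry key below are illustrative assumptions, not the actual olmo-1b params file added by this commit:

# Hedged sketch: ModelArgs with the new norm_type field. All values here are
# placeholders; "layernorm" is assumed to be the key registered in norm.py.
from executorch.examples.models.llama.model_args import ModelArgs

args = ModelArgs(
    dim=2048,               # placeholder
    n_layers=16,            # placeholder
    n_heads=16,             # placeholder
    norm_type="layernorm",  # default remains "rmsnorm"
)
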
examples/models/olmo/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.olmo.convert_weights import convert_weights


class OlmoModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "OlmoModel",
    "convert_weights",
]
