
Commit 216a1ec
Qualcomm AI Engine Direct - GA Qwen 2.5 0.5B (#12333)
## Summary
- Add a decoder_model_wrapper.py to ensure that the exported model can be fully delegated to the QNN backend
- Add an e2e script to run Qwen 2.5
- Support SpinQuant R3
- Replace Qwen2Attention with QCQwen2Attention
- Pre-compute freqs_cos and freqs_sin to bypass rotary embedding (see the sketch after the results)
- Replace Qwen2RMSNorm with torch.nn.RMSNorm
- Tag quant I/O to avoid inserting Q/DQ for I/O
- Reuse the ExecuTorch llama runner, llama_main

Note that accuracy is currently bad and needs more investigation.

## Reproduce command
```
python3 examples/qualcomm/oss_scripts/qwen/qwen.py -s <serial> -H <host> -m SM8750 --prompt "My favourite condiment is " -b build-android --decoder_model qwen2.5_0.5B --ptq 16a16w
```

## Results

### 7/9 ptq: 16a16w
Speed: 62 tok/sec on SM8750, seq_len = 128
Accuracy: Bad
Outputs:
```
I 00:00:02.944266 executorch:stats.h:108] Prompt Tokens: 6 Generated Tokens: 121
I 00:00:02.944270 executorch:stats.h:114] Model Load Time: 0.677000 (seconds)
I 00:00:02.944274 executorch:stats.h:124] Total inference time: 2.034000 (seconds) Rate: 59.488692 (tokens/second)
I 00:00:02.944279 executorch:stats.h:132] Prompt evaluation: 0.093000 (seconds) Rate: 64.516129 (tokens/second)
I 00:00:02.944283 executorch:stats.h:143] Generated 121 tokens: 1.941000 (seconds) Rate: 62.339001 (tokens/second)
I 00:00:02.944288 executorch:stats.h:151] Time to first generated token: 0.093000 (seconds)
I 00:00:02.944292 executorch:stats.h:158] Sampling time over 127 tokens: 0.059000 (seconds)
My favourite condiment is a thing, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan, and I am a fan
```

### 7/11 ptq: 16a8w
Speed: 135 tok/sec on SM8750, seq_len = 128
Accuracy: Seems better
Outputs:
```
I 00:00:00.734588 executorch:text_llm_runner.cpp:100] RSS after loading model: 829.648438 MiB (0 if unsupported)
I 00:00:00.734865 executorch:text_llm_runner.cpp:157] Max new tokens resolved: 122, given start_pos 0, num_prompt_tokens 6, max_context_len 128
I 00:00:00.784392 executorch:text_llm_runner.cpp:184] RSS after prompt prefill: 829.648438 MiB (0 if unsupported)
I 00:00:01.677137 executorch:text_llm_runner.cpp:204] RSS after finishing text generation: 829.648438 MiB (0 if unsupported)
I 00:00:01.677171 executorch:stats.h:108] Prompt Tokens: 6 Generated Tokens: 121
I 00:00:01.677180 executorch:stats.h:114] Model Load Time: 0.431000 (seconds)
I 00:00:01.677187 executorch:stats.h:124] Total inference time: 0.943000 (seconds) Rate: 128.313892 (tokens/second)
I 00:00:01.677193 executorch:stats.h:132] Prompt evaluation: 0.050000 (seconds) Rate: 120.000000 (tokens/second)
I 00:00:01.677201 executorch:stats.h:143] Generated 121 tokens: 0.893000 (seconds) Rate: 135.498320 (tokens/second)
I 00:00:01.677208 executorch:stats.h:151] Time to first generated token: 0.050000 (seconds)
I 00:00:01.677215 executorch:stats.h:158] Sampling time over 127 tokens: 0.017000 (seconds)
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]: QnnDsp <W> Function not called, PrepareLib isn't loaded!
/data/local/tmp/shewu/executorch/qwen_qnn_q16/outputs/: 1 file pulled. 0.7 MB/s (883 bytes in 0.001s)
INFO:root:Results[0]:
Setting up pretokenizer...
Pretokenizer set up
My favourite condiment is iced tea. I love it so much that I have to have it every day. I have a habit of making it at home. I have a few recipes for iced tea. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite iced tea recipes. I have a few favorite
```

cc: @winskuo-quic, @haowhsu-quic
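As a reference for the rotary-embedding bullet in the summary, below is a minimal eager-mode sketch of pre-computing freqs_cos/freqs_sin so that applying rotary position embedding reduces to elementwise multiplies and adds inside the delegated graph. Function names and parameters here are illustrative, not the committed API:

```python
import torch

# Sketch only: pre-compute the cos/sin tables once so the exported graph
# never contains the trig ops of rotary embedding. Names are illustrative.
def precompute_freqs(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # One frequency per channel pair: theta^(-2i / head_dim)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    angles = torch.outer(t, freqs)  # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

def apply_rotary(x, freqs_cos, freqs_sin):
    # x: (batch, seq, heads, head_dim); treat channel pairs as (real, imag)
    x_r, x_i = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    cos = freqs_cos[: x.shape[1]].unsqueeze(0).unsqueeze(2)  # (1, seq, 1, head_dim//2)
    sin = freqs_sin[: x.shape[1]].unsqueeze(0).unsqueeze(2)
    out_r = x_r * cos - x_i * sin  # rotate each pair by the cached angle
    out_i = x_r * sin + x_i * cos
    return torch.stack([out_r, out_i], dim=-1).flatten(-2).type_as(x)
```

Since the tables depend only on position, they can be passed to the model as constant tensors, which is what lets the whole attention block be delegated.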
1 parent c28a0b4 commit 216a1ec

27 files changed: +2096 -100 lines

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from .annotate_unbind import AnnotateUnbind
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from .convert_linear_to_conv2d import ConvertLinearToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
 from .decompose_cdist import DecomposeCDist
@@ -48,6 +49,7 @@
     AnnotateUnbind,
     ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
+    ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
     DecomposeCDist,
```

backends/qualcomm/_passes/build_quant_io.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -39,6 +39,11 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if QCOM_QUANTIZED_IO in n.meta:
                 n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO])
 
+        spec = []
+        for user in list(call_delegate[0].users):
+            spec.append(self._make_spec(user.meta["val"]))
+        call_delegate[0].meta["spec"] = tuple(spec)
+
     def call(self, graph_module: torch.fx.GraphModule):
         self._build(graph_module)
         graph_module.graph.eliminate_dead_code()
```

backends/qualcomm/_passes/convert_conv1d_to_conv2d.py

Lines changed: 3 additions & 28 deletions
```diff
@@ -9,7 +9,7 @@
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
 from executorch.exir.pass_base import ExportPass, PassResult
 
-from .utils import copy_meta
+from .utils import append_qdq, copy_meta
 
 
 class ConvertConv1dToConv2d(ExportPass):
@@ -26,31 +26,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram):
             torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
         }
 
-    def append_qdq(
-        self,
-        graph_module: torch.fx.GraphModule,
-        node: torch.fx.Node,
-        qdq_node: torch.fx.Node,
-    ):
-        q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-        dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-        if qdq_node.target not in {q_op, dq_op}:
-            return node
-
-        with graph_module.graph.inserting_after(node):
-            q_args = (node, *qdq_node.args[1:])
-            q_node = graph_module.graph.create_node("call_function", q_op, q_args)
-            q_node.meta = copy_meta(node.meta)
-            q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
-            with graph_module.graph.inserting_after(q_node):
-                dq_args = (q_node, *qdq_node.args[1:])
-                dq_node = graph_module.graph.create_node(
-                    "call_function", dq_op, dq_args
-                )
-                dq_node.meta = copy_meta(node.meta)
-
-        return dq_node
-
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
         for node in graph.nodes:
@@ -69,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 unsqueeze_node.meta = copy_meta(
                     input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                 )
-                qdq_node_after_unsqueeze = self.append_qdq(
+                qdq_node_after_unsqueeze = append_qdq(
                     graph_module=graph_module,
                     node=unsqueeze_node,
                     qdq_node=input_node,
@@ -139,7 +114,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 conv2d_node.meta = copy_meta(
                     node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                 )
-                qdq_node_after_conv2d = self.append_qdq(
+                qdq_node_after_conv2d = append_qdq(
                     graph_module=graph_module,
                     node=conv2d_node,
                     qdq_node=list(node.users)[0],
```
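The `append_qdq` imported from `.utils` above is not shown in this commit excerpt; presumably it is the method deleted here, promoted to a module-level helper shared by both conversion passes. A sketch under that assumption:

```python
# Sketch of the shared helper now expected in _passes/utils.py, assuming it
# mirrors the method removed above (copy_meta already lives in utils).
# It inserts a quantize/dequantize pair after `node`, reusing the quant
# params of an adjacent Q/DQ node, and returns the new DQ node.
import torch


def append_qdq(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    qdq_node: torch.fx.Node,
):
    q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
    dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    if qdq_node.target not in {q_op, dq_op}:
        return node  # no quant params to propagate; leave the graph as-is

    with graph_module.graph.inserting_after(node):
        q_args = (node, *qdq_node.args[1:])
        q_node = graph_module.graph.create_node("call_function", q_op, q_args)
        q_node.meta = copy_meta(node.meta)
        q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
        with graph_module.graph.inserting_after(q_node):
            dq_args = (q_node, *qdq_node.args[1:])
            dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
            dq_node.meta = copy_meta(node.meta)

    return dq_node
```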
backends/qualcomm/_passes/convert_linear_to_conv2d.py

Lines changed: 232 additions & 0 deletions (new file)

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.qualcomm._passes.utils import append_qdq, copy_meta
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix


def _pad_list_to_4(lst):
    return lst + [1] * (4 - len(lst)) if len(lst) < 4 else lst[:4]


class ConvertLinearToConv2d(ExportPass):
    """
    Replace aten.linear.default with equivalent 1x1 conv2d using call_function nodes.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super().__init__()
        self.edge_program = edge_program
        self.per_block_dq = torch.ops.torchao.dequantize_affine.default

    def _register_tensor(
        self,
        gm: torch.fx.GraphModule,
        node: torch.fx.Node,
        tensor_constant: torch.Tensor,
    ) -> torch.fx.Node:
        new_node_name = get_new_attr_name_with_prefix(node.name)(gm)
        gm.register_buffer(new_node_name, tensor_constant)

        with gm.graph.inserting_before(node):
            get_attr_node = gm.graph.get_attr(new_node_name)
            get_attr_node.meta["val"] = tensor_constant
        return get_attr_node

    def _append_dq(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
        qdq_node: torch.fx.Node,
    ):
        q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
        dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default

        if qdq_node.target not in {q_op, dq_op}:
            return node

        with graph_module.graph.inserting_after(node):
            dq_args = (node, *qdq_node.args[1:])
            dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args)
            dq_node.meta = copy_meta(node.meta)
        return dq_node

    def _create_node(
        self, graph_module, target, args, meta_node, new_meta_val, qdq_node
    ):
        new_node = graph_module.graph.call_function(target, args)
        new_node.meta = copy_meta(
            meta_node.meta,
            lambda m, new_meta_val=new_meta_val: {
                **m,
                "val": new_meta_val,
            },
        )
        dq_node = append_qdq(
            graph_module=graph_module,
            node=new_node,
            qdq_node=qdq_node,
        )
        return dq_node

    def _reshape_weight(self, graph_module, weight_node, dq_node):
        # After export, constant node will be placeholder from edge_program
        weight_val = get_parameter(weight_node, self.edge_program)
        assert weight_val is not None, "Cannot get the weight in linear node."

        weight_val = weight_val.reshape(*weight_val.shape, 1, 1)
        # Create the new weight node when several nodes share the same weight,
        # such as embedding and lm_head in LLM.
        if len(list(weight_node.users)) > 1:
            weight_node = self._register_tensor(graph_module, weight_node, weight_val)
            dq_node = self._append_dq(graph_module, weight_node, dq_node)
        else:
            set_parameter(
                (
                    torch.nn.Parameter(weight_val)
                    if weight_val.dtype == torch.float
                    else weight_val
                ),
                weight_node,
                self.edge_program,
            )

        # Update node meta val
        weight_node.meta["val"] = weight_node.meta["val"].reshape(weight_val.shape)
        dq_node.meta["val"] = dq_node.meta["val"].reshape(weight_val.shape)
        # Update block size for per-block quant
        if dq_node.target == self.per_block_dq:
            new_args = list(dq_node.args)
            # pad block size
            new_args[1] = _pad_list_to_4(list(new_args[1]))
            dq_node.args = tuple(new_args)

        return dq_node

    def call(self, graph_module: GraphModule):
        graph = graph_module.graph

        for node in list(graph.nodes):
            if node.target == torch.ops.aten.linear.default:
                input_node = node.args[0]
                # In quantization flow, weight_arg will be dq node.
                weight_arg = node.args[1]
                weight_node = (
                    weight_arg if weight_arg.op == "placeholder" else weight_arg.args[0]
                )
                bias_arg = node.args[2] if len(node.args) > 2 else None

                input_meta_val = input_node.meta["val"]
                output_meta_val = node.meta["val"]
                if bias_arg:
                    bias_meta_val = bias_arg.meta["val"]

                rank = input_meta_val.ndim
                with graph.inserting_before(node):
                    # Step 1: reshape input
                    # rank = 2: (dim, C) -> (1, C, 1, dim)
                    # rank = 3: (N, dim, C) -> (N, C, 1, dim)
                    # rank = 4: (N, H, W, C) -> (N, C, H, W)
                    order = (0, 3, 1, 2)
                    if rank <= 3:
                        # (dim, C) -> (1, C, 1, dim)
                        # (N, dim, C) -> (N, C, 1, dim)
                        shape = (
                            (1, *input_meta_val.shape, 1)
                            if rank == 2
                            else (*input_meta_val.shape, 1)
                        )
                        x_meta_val = input_meta_val.reshape(shape)
                        input_node = self._create_node(
                            graph_module,
                            torch.ops.aten.reshape.default,
                            (input_node, shape),
                            node,
                            x_meta_val,
                            input_node,
                        )
                        order = (0, 2, 3, 1)

                    x_meta_val = x_meta_val.permute(order)
                    x = self._create_node(
                        graph_module,
                        torch.ops.aten.permute.default,
                        (input_node, order),
                        node,
                        x_meta_val,
                        input_node,
                    )

                    # Step 2: reshape weight
                    weight_arg = self._reshape_weight(
                        graph_module, weight_node, weight_arg
                    )
                    weight_meta_val = weight_arg.meta["val"]

                    conv_args = [x, weight_arg]
                    conv_args_meta_val = [x_meta_val, weight_meta_val]
                    if bias_arg:
                        conv_args.append(bias_arg)
                        conv_args_meta_val.append(bias_meta_val)
                    else:
                        conv_args.append(None)
                        conv_args_meta_val.append(None)

                    conv_args.extend(
                        [[1, 1], [0, 0], [1, 1], 1]
                    )  # stride, padding, dilation, groups
                    conv_node_val = torch.nn.functional.conv2d(
                        *conv_args_meta_val,
                        stride=(1, 1),
                        padding=(0, 0),
                        dilation=(1, 1),
                        groups=1,
                    )
                    conv_node = self._create_node(
                        graph_module,
                        torch.ops.aten.conv2d.default,
                        tuple(conv_args),
                        node,
                        conv_node_val,
                        list(node.users)[0],
                    )

                    # Step 3: restore shape
                    # rank = 2: (1, C, 1, dim) -> (dim, C)
                    # rank = 3: (N, C, 1, dim) -> (N, dim, C)
                    # rank = 4: (N, C, H, W) -> (N, H, W, C)
                    order = (0, 2, 3, 1) if rank == 4 else (0, 3, 1, 2)
                    y_meta_val = conv_node_val.permute(order)
                    y = self._create_node(
                        graph_module,
                        torch.ops.aten.permute.default,
                        (conv_node, order),
                        node,
                        y_meta_val,
                        list(node.users)[0],
                    )
                    if rank <= 3:
                        target_shape = output_meta_val.shape
                        y_meta_val = y_meta_val.reshape(target_shape)
                        y = self._create_node(
                            graph_module,
                            torch.ops.aten.reshape.default,
                            (y, target_shape),
                            node,
                            y_meta_val,
                            list(node.users)[0],
                        )

                node.replace_all_uses_with(y)
                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
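A quick eager-mode sanity check of the equivalence this pass relies on (a self-contained sketch, not part of the commit): a linear layer produces the same values as a 1x1 conv2d once the input follows the rank-3 reshape/permute pattern above and the weight is viewed as (out, in, 1, 1):

```python
import torch

# Verify: aten.linear on a (N, seq, C_in) input equals a 1x1 conv2d after
# (N, seq, C_in) -> (N, C_in, 1, seq) -> conv2d -> back, matching the pass.
N, seq, C_in, C_out = 2, 7, 16, 32
x = torch.randn(N, seq, C_in)
linear = torch.nn.Linear(C_in, C_out)

ref = linear(x)  # (N, seq, C_out)

w4d = linear.weight.reshape(C_out, C_in, 1, 1)        # (out, in, 1, 1)
x4d = x.reshape(N, seq, C_in, 1).permute(0, 2, 3, 1)  # (N, C_in, 1, seq)
y4d = torch.nn.functional.conv2d(x4d, w4d, linear.bias)  # (N, C_out, 1, seq)
out = y4d.permute(0, 3, 1, 2).reshape(N, seq, C_out)

print(torch.allclose(ref, out, atol=1e-5))  # True, up to float rounding
```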

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -15,6 +15,7 @@
     AnnotateUnbind,
     ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
+    ConvertLinearToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
     DecomposeCDist,
@@ -82,7 +83,6 @@ def get_capture_program_passes():
         (AnnotateStack, True),
         (AnnotateUnbind, True),
         (ConvertBmmToMatmul, False),
-        (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
         (DecomposeMinMaxDim, True),
@@ -92,7 +92,7 @@ def get_capture_program_passes():
         (I64toI32, True),
         (LayoutTransform, True),
         (RecomposePixelUnshuffle, True),
-        (RecomposeRmsNorm, False),
+        (RecomposeRmsNorm, True),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
         (TagQuantIO, False),
@@ -190,6 +190,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(RemoveRedundancy(quantization_capture=True))
         self.add_pass(ReduceDynamicRange())
         self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
+        self.add_pass(RecomposeRmsNorm(quantization_capture=True))
         self.add_pass(ReplaceArangeArgs())
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
@@ -203,7 +204,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(LiftConstantScalarOperands())
         return self._transform(graph_module)
 
-    def transform_for_export_pipeline(self, exported_program: ExportedProgram):
+    def transform_for_export_pipeline(
+        self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False
+    ):
         self.add_pass(DecomposeCDist())
         self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(DecomposeRoll())
@@ -213,6 +216,8 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
         # this pass will rewrite state_dict, it needs to be accomplished before
         # to_edge_transform_and_lower
         self.add_pass(ConvertConv1dToConv2d(exported_program))
+        if convert_linear_to_conv2d:
+            self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
         self.add_pass(LiftConstantScalarOperands())
         self._transform(exported_program.graph_module)
```
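Callers opt in to the linear rewrite per model. A minimal sketch, assuming the pass-manager class defined in this file is named `QnnPassManager` and `exported_program` comes from `torch.export.export` (neither is shown in this hunk):

```python
# Hypothetical caller-side usage of the new flag; names outside this hunk
# (QnnPassManager, exported_program) are assumptions for illustration.
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

pass_manager = QnnPassManager()
# Opt in so aten.linear ops are rewritten as 1x1 conv2d before lowering.
pass_manager.transform_for_export_pipeline(
    exported_program, convert_linear_to_conv2d=True
)
```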

0 commit comments