
Commit d36f7ea

Static LLM Decoder Refactor
1 parent 9c73b5d

File tree: 5 files changed (+14, -10 lines changed)

backends/qualcomm/tests/test_qnn_delegate.py
Lines changed: 2 additions & 2 deletions

@@ -4588,7 +4588,7 @@ def test_static_qwen2_5(self):
             msg["inference_speed"], inference_speed_ref[self.model]
         )
 
-    def test_qwen3(self):
+    def test_static_qwen3(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
 
@@ -4611,7 +4611,7 @@ def test_qwen3(self):
             "--ptq",
             "16a8w",
             "--decoder_model",
-            "qwen3_0.6b",
+            "qwen3_0_6b",
             "--model_mode",
             "hybrid",
             "--prefill_ar_len",

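Note: the new "qwen3_0_6b" value passed to --decoder_model (underscores instead of a dot) matches the keys added to examples/qualcomm/oss_scripts/llama/decoder_constants.py below.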
examples/qualcomm/oss_scripts/llama/decoder_constants.py
Lines changed: 2 additions & 0 deletions

@@ -15,5 +15,7 @@
     "stories110m": "llama2",
     "llama3_2": "llama3",
     "qwen2_5": "qwen2_5",
+    "qwen3_0_6b": "qwen2_5",  # TODO: temp workaround, use special token for qwen3 in runner
+    "qwen3_1_7b": "qwen2_5",
     "phi_4_mini": "phi_4_mini",
 }
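As context for the TODO, a hedged sketch of how a decoder-to-version mapping like this is typically consumed on the runner side; DECODER_MODEL_VERSION mirrors the diff above, but RUNNER_EOS_TOKEN and eos_for are hypothetical stand-ins, not code from this repository:

# Hypothetical consumer of the mapping above: the runner picks its
# prompt/special-token handling by "version" family, so routing both
# qwen3 keys to "qwen2_5" reuses Qwen2.5's handling until the runner
# gains Qwen3 special tokens, per the TODO.
DECODER_MODEL_VERSION = {
    "qwen2_5": "qwen2_5",
    "qwen3_0_6b": "qwen2_5",  # temp workaround from this diff
    "qwen3_1_7b": "qwen2_5",
}

RUNNER_EOS_TOKEN = {  # hypothetical per-family table
    "qwen2_5": "<|im_end|>",
}


def eos_for(decoder_model: str) -> str:
    return RUNNER_EOS_TOKEN[DECODER_MODEL_VERSION[decoder_model]]


assert eos_for("qwen3_0_6b") == eos_for("qwen2_5")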

examples/qualcomm/oss_scripts/llama/llama.py
Lines changed: 1 addition & 1 deletion

@@ -428,7 +428,7 @@ def compile(args, pte_filename, tokenizer):
     if args.checkpoint is None:  # HF models
         checkpoint = download_and_convert_hf_checkpoint(
             SUPPORTED_HF_MODELS[args.decoder_model].repo_id,
-            SUPPORTED_HF_MODELS[args.decoder_model].convert_weights,
+            SUPPORTED_HF_MODELS[args.decoder_model].convert_weights.__func__,
         )
         state_dict = torch.load(
             checkpoint, weights_only=True, map_location="cpu", mmap=True
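A hedged reading of the .__func__ change: if each SUPPORTED_HF_MODELS entry stores convert_weights as a staticmethod object (an assumption, not verified against this repository), instance-attribute lookup returns the descriptor itself, which is only directly callable on Python 3.10+, and .__func__ unwraps the plain function. Minimal sketch with stand-in names (HFModel and convert_qwen_weights are illustrative):

from dataclasses import dataclass


def convert_qwen_weights(path):  # hypothetical converter function
    return {"checkpoint": path}


@dataclass
class HFModel:  # stand-in for the real SUPPORTED_HF_MODELS entry type
    repo_id: str
    convert_weights: staticmethod


entry = HFModel("Qwen/Qwen3-0.6B", staticmethod(convert_qwen_weights))

# Instance-attribute lookup returns the staticmethod descriptor itself,
# which is only directly callable on Python >= 3.10; ".__func__" unwraps
# the underlying function so it can be passed around and called anywhere.
fn = entry.convert_weights.__func__
print(fn("ckpt.bin"))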

examples/qualcomm/oss_scripts/llama/model/static_llama.py
Lines changed: 8 additions & 6 deletions

@@ -104,7 +104,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
 
         self.scale = float(self.head_dim) ** 0.5
 
-        if hasattr(config, "enable_r3") and config.enable_r3:
+        if getattr(config, "enable_r3", False):
             self.register_buffer(
                 "r3_weight",
                 torch.tensor(
@@ -223,18 +223,20 @@ def forward_sha(  # noqa: C901
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
             q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
-            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                q[i] = torch.matmul(q[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
+            if getattr(self.config, "enable_r3", False):
+                q[i] = torch.matmul(q[i], self.r3_weight)
+
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
-            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                k[i] = torch.matmul(k[i], self.r3_weight)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
+            if getattr(self.config, "enable_r3", False):
+                k[i] = torch.matmul(k[i], self.r3_weight)
+            k[i] = k[i].transpose(1, 2)
 
         output_y = []
         kh, vh = [], []
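Two hedged sketches of what this hunk changes: getattr(..., False) folds the hasattr-and-read guard into one lookup, and the key transpose now runs after the R3 rotation (which itself moves after the optional QK-norm). The shapes and the Config stand-in below are assumptions, not taken from static_llama.py:

import torch


# (1) getattr with a default is equivalent to the old hasattr-and-read guard.
class Config:
    pass  # "enable_r3" may or may not be set on an instance


cfg = Config()
assert (hasattr(cfg, "enable_r3") and cfg.enable_r3) == getattr(
    cfg, "enable_r3", False
)

# (2) Why the transpose must come last: assume a per-head key of shape
# (batch, seq, head_dim) and an R3 rotation of shape (head_dim, head_dim).
# Rotating before the transpose contracts over head_dim as intended; the
# old order would try to contract over seq and fail (or silently mix axes
# whenever seq happens to equal head_dim).
batch, seq, head_dim = 1, 8, 64
k = torch.randn(batch, seq, head_dim)
r3 = torch.randn(head_dim, head_dim)

new_order = torch.matmul(k, r3).transpose(1, 2)  # rotate, then (b, d, s)
try:
    torch.matmul(k.transpose(1, 2), r3)  # old order: (b, d, s) @ (d, d)
except RuntimeError as err:
    print("old ordering fails:", err)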

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
  * @file
  *
  * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B, Qwen3 0.6B
- * / 1.7B phi4-mini-instruct with Qualcomm AI Engine Direct.
+ * / 1.7B, phi4-mini-instruct with Qualcomm AI Engine Direct.
 *
 */
 
