Skip to content

Commit cf669e3

Browse files
authored
Qualcomm AI Engine Direct - Static LLM Decoder Refactor (#13314)
### Summary
- Update unit-test name
- Revert R3 changes to original behavior
- Minor refactor of code logic

### Test plan
NA
1 parent 4438d31 commit cf669e3

File tree

6 files changed

+15
-11
lines changed

6 files changed

+15
-11
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4590,7 +4590,7 @@ def test_static_qwen2_5(self):
45904590
msg["inference_speed"], inference_speed_ref[self.model]
45914591
)
45924592

4593-
def test_qwen3(self):
4593+
def test_static_qwen3(self):
45944594
if not self.required_envs():
45954595
self.skipTest("missing required envs")
45964596

@@ -4613,7 +4613,7 @@ def test_qwen3(self):
46134613
"--ptq",
46144614
"16a8w",
46154615
"--decoder_model",
4616-
"qwen3_0.6b",
4616+
"qwen3_0_6b",
46174617
"--model_mode",
46184618
"hybrid",
46194619
"--prefill_ar_len",

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class Qwen3_0_6B(HFModel):
6868
@register_hf_model("qwen3_1_7b")
6969
@dataclass(init=False, frozen=True)
7070
class Qwen3_1_7B(HFModel):
71-
repo_id: str = "Qwen/Qwen/Qwen3-1.7B"
71+
repo_id: str = "Qwen/Qwen3-1.7B"
7272
params_path: str = os.path.join(
7373
BASE_DIR, "../../../models/qwen3/config/1_7b_config.json"
7474
)

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@
1515
"stories110m": "llama2",
1616
"llama3_2": "llama3",
1717
"qwen2_5": "qwen2_5",
18+
"qwen3_0_6b": "qwen2_5", # TODO: temp workaround, use special token for qwen3 in runner
19+
"qwen3_1_7b": "qwen2_5",
1820
"phi_4_mini": "phi_4_mini",
1921
}

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def compile(args, pte_filename, tokenizer):
429429
if args.checkpoint is None: # HF models
430430
checkpoint = download_and_convert_hf_checkpoint(
431431
SUPPORTED_HF_MODELS[args.decoder_model].repo_id,
432-
SUPPORTED_HF_MODELS[args.decoder_model].convert_weights,
432+
SUPPORTED_HF_MODELS[args.decoder_model].convert_weights.__func__,
433433
)
434434
state_dict = torch.load(
435435
checkpoint, weights_only=True, map_location="cpu", mmap=True

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
104104

105105
self.scale = float(self.head_dim) ** 0.5
106106

107-
if hasattr(config, "enable_r3") and config.enable_r3:
107+
if getattr(config, "enable_r3", False):
108108
self.register_buffer(
109109
"r3_weight",
110110
torch.tensor(
@@ -223,18 +223,20 @@ def forward_sha( # noqa: C901
223223
if self.use_qk_norm and self.qk_norm_before_rope:
224224
q[i] = self.q_norm_fn(q[i])
225225
q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
226-
if hasattr(self.config, "enable_r3") and self.config.enable_r3:
227-
q[i] = torch.matmul(q[i], self.r3_weight)
228226
if self.use_qk_norm and not self.qk_norm_before_rope:
229227
q[i] = self.q_norm_fn(q[i])
228+
if getattr(self.config, "enable_r3", False):
229+
q[i] = torch.matmul(q[i], self.r3_weight)
230+
230231
for i in range(len(k)):
231232
if self.use_qk_norm and self.qk_norm_before_rope:
232233
k[i] = self.k_norm_fn(k[i])
233-
k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
234-
if hasattr(self.config, "enable_r3") and self.config.enable_r3:
235-
k[i] = torch.matmul(k[i], self.r3_weight)
234+
k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin)
236235
if self.use_qk_norm and not self.qk_norm_before_rope:
237236
k[i] = self.k_norm_fn(k[i])
237+
if getattr(self.config, "enable_r3", False):
238+
k[i] = torch.matmul(k[i], self.r3_weight)
239+
k[i] = k[i].transpose(1, 2)
238240

239241
output_y = []
240242
kh, vh = [], []

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
* @file
1111
*
1212
* This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B, Qwen3 0.6B
13-
* / 1.7B phi4-mini-instruct with Qualcomm AI Engine Direct.
13+
* / 1.7B, phi4-mini-instruct with Qualcomm AI Engine Direct.
1414
*
1515
*/
1616

0 commit comments

Comments (0)