
Commit 34410f4

Qualcomm AI Engine Direct - codegen2-1B (#15408)
Summary

Enable codegen2 1B.

$ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s -m --model_mode kv --max_seq_len 128 --decoder_model codegen2_1b --prompt "def hello world():"

Test plan

$ python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_codegen2_1b --model --build_folder build-android/ --executorch_root . -H -s

@winskuo-quic @haowhsu-quic
1 parent 4a75896 commit 34410f4

18 files changed: +502 −68 lines changed


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 0 deletions
@@ -138,6 +138,9 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
     weight = node.args[1]
     input_qspec_map[weight] = quantization_config.weight
 
+    if len(node.args) > 2 and isinstance(node.args[2], Node):
+        input_qspec_map[node.args[2]] = quantization_config.bias(node)
+
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
         output_qspec=quantization_config.output_activation,

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 61 additions & 0 deletions
@@ -5762,6 +5762,67 @@ def test_qnn_backend_seq_mse(self):
 
 
 class TestExampleLLMScript(TestQNN):
+    def test_codegen2_1b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "def hello_world():"
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            prompt,
+            "--temperature",
+            "0",
+            "--decoder_model",
+            "codegen2_1b",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "128",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        golden_start_with = "def hello_world():"
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                if not self.compile_only:
+                    model_out = msg["result"][0]
+                    self.assertTrue(
+                        model_out.startswith(golden_start_with),
+                        f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
+                    )
+                if not self.enable_x86_64:
+                    pte_size = msg["pte_size"]
+                    self.assertLessEqual(pte_size, 1_200_000_000)  # 1200MB
+                if not self.compile_only and not self.enable_x86_64:
+                    self.assertGreaterEqual(msg["inference_speed"], 60)
+
     def test_static_gemma_2b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")

examples/models/codegen/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.codegen.convert_weight import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class CodeGenModel(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "CodeGenModel",
    "convert_weights",
]

examples/models/codegen/config/config.json

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 8192,
  "n_heads": 16,
  "n_kv_heads": 16,
  "n_layers": 16,
  "vocab_size": 51200,
  "norm_eps": 1e-05,
  "max_seq_len": 2048,
  "bos_idx": 1,
  "eos_idx": 2,
  "model_architecture": "CodeGenModel",
  "use_hf_rope": true,
  "partial_rotary_factor": 0.5,
  "use_ffn_norm": false,
  "norm_type": "layernorm",
  "output_bias": true
}
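
For orientation (not part of this commit): with the values above, each attention head spans 2048 / 16 = 128 dimensions, and `partial_rotary_factor: 0.5` means only the first 64 dimensions of each head receive rotary embeddings. The `partial_rotary_dim` derived in `llama.py` further below is the reciprocal of that factor. A minimal sketch of the arithmetic, using only values copied from this config:

```python
# Sketch only; the dict literal mirrors a subset of the codegen config above.
cfg = {"dim": 2048, "n_heads": 16, "partial_rotary_factor": 0.5}

head_dim = cfg["dim"] // cfg["n_heads"]                      # 2048 // 16 = 128
rotary_dim = int(head_dim * cfg["partial_rotary_factor"])    # 64 rotary dims per head
partial_rotary_dim = int(1 // cfg["partial_rotary_factor"])  # 2, as computed in llama.py
print(head_dim, rotary_dim, partial_rotary_dim)              # 128 64 2
```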

examples/models/codegen/convert_weight.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import argparse
import os
from typing import Dict

import torch

from torchtune.models.convert_weights import get_mapped_key

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_HF__CODEGEN_2_FROM_META = {
    "tok_embeddings.weight": "transformer.wte.weight",
    "layers.{}.attention_norm.weight": "transformer.h.{}.ln_1.weight",
    "layers.{}.attention_norm.bias": "transformer.h.{}.ln_1.bias",
    "layers.{}.attention.wq.weight": "transformer.h.{}.attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "transformer.h.{}.attn.k_proj.weight",
    "layers.{}.attention.wv.weight": "transformer.h.{}.attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "transformer.h.{}.attn.out_proj.weight",
    "layers.{}.feed_forward.fc_in.weight": "transformer.h.{}.mlp.fc_in.weight",
    "layers.{}.feed_forward.fc_in.bias": "transformer.h.{}.mlp.fc_in.bias",
    "layers.{}.feed_forward.fc_out.weight": "transformer.h.{}.mlp.fc_out.weight",
    "layers.{}.feed_forward.fc_out.bias": "transformer.h.{}.mlp.fc_out.bias",
    "norm.weight": "transformer.ln_f.weight",
    "norm.bias": "transformer.ln_f.bias",
    "output.weight": "lm_head.weight",
    "output.bias": "lm_head.bias",
}


def codegen_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    converted_state_dict = {}
    keys_to_remove = []
    for key in state_dict:
        if ".attn.causal_mask" in key:
            keys_to_remove.append(key)
    for key in keys_to_remove:
        state_dict.pop(key)
    inverted_mapping_dict = {v: k for k, v in _HF__CODEGEN_2_FROM_META.items()}
    for key, value in state_dict.items():
        if key.endswith("attn.qkv_proj.weight"):
            mp_num = 8  # This number is from modeling_codegen.py
            dim, dim_kv = value.shape
            block = dim // mp_num
            split_size = block // 3

            qkv_blocks = value.reshape(mp_num, block, dim_kv)
            q_blocks = qkv_blocks[:, 0:split_size, :]
            v_blocks = qkv_blocks[:, split_size : 2 * split_size, :]
            k_blocks = qkv_blocks[:, 2 * split_size : 3 * split_size, :]

            q = q_blocks.reshape(-1, dim_kv)
            v = v_blocks.reshape(-1, dim_kv)
            k = k_blocks.reshape(-1, dim_kv)

            for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
                new_key = key.replace("qkv_proj", new_key)
                new_key = get_mapped_key(new_key, inverted_mapping_dict)
                converted_state_dict[new_key] = new_value
        else:
            mapped_key = get_mapped_key(key, inverted_mapping_dict)
            converted_state_dict[mapped_key] = value

    return converted_state_dict


def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
    pt_path = os.path.join(input_dir_or_checkpoint, "pytorch_model.bin")
    print("Loading checkpoint from file...")
    sd = torch.load(pt_path, map_location="cpu", weights_only=True)
    print("Converting checkpoint...")
    sd = codegen_hf_to_meta(sd)

    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert Codegen weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files, or path to a single checkpoint file.",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()
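
A quick, self-contained sanity check (not part of the commit) of the `qkv_proj` handling above: as the code implies, CodeGen stores the fused qkv weight as `mp_num` interleaved blocks, each holding a q, v, k slice in that order, so slicing the fused matrix into three contiguous thirds would not recover the per-head projections. The sketch uses codegen2-1B's hidden size with synthetic tensor contents.

```python
import torch

dim_kv = 2048  # hidden size of codegen2-1B
fused = torch.arange(3 * dim_kv * dim_kv, dtype=torch.float32).reshape(3 * dim_kv, dim_kv)

mp_num = 8                 # mirrors modeling_codegen.py, as noted above
dim, _ = fused.shape       # 6144
block = dim // mp_num      # 768 rows per interleaved block
split_size = block // 3    # 256 rows each of q, v, k inside every block

qkv_blocks = fused.reshape(mp_num, block, dim_kv)
q = qkv_blocks[:, 0:split_size, :].reshape(-1, dim_kv)
v = qkv_blocks[:, split_size : 2 * split_size, :].reshape(-1, dim_kv)
k = qkv_blocks[:, 2 * split_size :, :].reshape(-1, dim_kv)

assert q.shape == k.shape == v.shape == (dim_kv, dim_kv)
# The first q row is the first fused row, while the first v row sits split_size
# rows into the fused matrix: the split undoes the interleaving rather than
# chunking the fused weight into contiguous thirds.
assert torch.equal(q[0], fused[0]) and torch.equal(v[0], fused[split_size])
```

If the converter needs to be run by hand, note that despite the `input_dir` help text, `convert_weights` joins the path with `pytorch_model.bin`, so pass the checkpoint directory rather than a single file, e.g. `python convert_weight.py <hf_checkpoint_dir> <output.pth>` (argument placeholders illustrative).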

examples/models/llama/model_args.py

Lines changed: 5 additions & 0 deletions
@@ -46,12 +46,17 @@ class ModelArgs:
     head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
+    model_architecture: str = (
+        "LlamaForCausalLM"  # This setting is currently only supported for the QNN backend
+    )
     norm_eps: float = 1e-5
     post_attention_norm: bool = False
     post_ffn_norm: bool = False
     max_batch_size: int = 1
     max_seq_len: int = 2048
     max_context_len: int = 2048
+    use_ffn_norm: bool = True
+    output_bias: bool = False
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 14 additions & 8 deletions
@@ -5,13 +5,14 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
-4. Gemma 2B
-5. Gemma3 1B
-6. Phi4-mini-instruct
-7. QWEN2.5 0.5B / 1.5B
-8. QWEN3 0.6B / 1.7B
-9. SmolLM2 135M
-10. SmolLM3 3B
+4. Codegen2 1B
+5. Gemma 2B
+6. Gemma3 1B
+7. Phi4-mini-instruct
+8. QWEN2.5 0.5B / 1.5B
+9. QWEN3 0.6B / 1.7B
+10. SmolLM2 135M
+11. SmolLM3 3B
 
 
 We offer the following modes to execute the model:
@@ -80,6 +81,12 @@ Default example using kv mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
+#### Codegen2
+Default example using kv mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model codegen2_1b --model_mode kv --max_seq_len 1024 --prompt "def hello_world():"
+```
+
 #### Gemma 2B
 Default example using hybrid mode
 ```bash
@@ -135,7 +142,6 @@ Default example using kv mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
-
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.
 

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -23,6 +23,9 @@
     get_ptq_per_channel_quant_config,
 )
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.examples.models.codegen import (
+    convert_weights as convert_codegen_weights,
+)
 
 from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
@@ -331,6 +334,28 @@ class Gemma_2B(LLMModelConfig):
     )
 
 
+@register_llm_model("codegen2_1b")
+@dataclass(init=False, frozen=True)
+class Codegen(LLMModelConfig):
+    repo_id: str = "Salesforce/codegen2-1B_P"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/codegen/config/config.json"
+    )
+    convert_weights = convert_codegen_weights
+    transform_weight = True
+    instruct_model = False
+    num_sharding = 1
+    # quant config
+    ptq = QuantDtype.use_16a8w
+    group_size = None
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    custom_annotation = ()
+
+
 @register_llm_model("gemma3-1b")
 @dataclass(init=False, frozen=True)
 class Gemma3(LLMModelConfig):

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
@@ -25,4 +25,5 @@
     "qwen3-1_7b": "qwen3",
     "smollm2_135m": "smollm2_135m",
     "smollm3-3b": "smollm3",
+    "codegen2_1b": "codegen",
 }

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 21 additions & 9 deletions
@@ -445,7 +445,6 @@ def compile(
     kv_config.use_kv_cache = True
     kv_config.enable_r3 = decoder_model_config.r3
     kv_config.kv_io_bit_width = decoder_model_config.get_kv_io_bit_width()
-
     if decoder_model_config.masked_softmax:
         if is_qnn_sdk_version_less_than("2.35"):
             logging.warning(
@@ -561,25 +560,32 @@ def compile(
 
     if decoder_model_config.transform_weight:
         # Change to HuggingFace weight to improve the performance of RoPE in HTP backend.
-        def permute(w, heads):
+        def permute(w, heads, partial_rotary_dim):
             dim_0 = w.size(0)
             dim_1 = w.size(1)
-            return (
-                w.view(heads, dim_0 // heads // 2, 2, dim_1)
-                .transpose(1, 2)
+            transformed_weight = (
+                w.view(heads, -1, dim_0 // heads // 2 // partial_rotary_dim, 2, dim_1)
+                .transpose(2, 3)
                 .reshape(dim_0, dim_1)
             )
+            return transformed_weight
 
         n_heads = llama_instance_list[0].n_heads
         n_kv_heads = llama_instance_list[0].n_kv_heads
         n_layers = llama_instance_list[0].n_layers
-
+        partial_rotary_dim = int(
+            1 // kv_config.partial_rotary_factor
+        )  # TODO Handle cases where input size isn't divisible.
        for layer_i in range(n_layers):
             state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute(
-                state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads
+                state_dict[f"layers.{layer_i}.attention.wq.weight"],
+                n_heads,
+                partial_rotary_dim,
             )
             state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute(
-                state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads
+                state_dict[f"layers.{layer_i}.attention.wk.weight"],
+                n_kv_heads,
+                partial_rotary_dim,
             )
 
     for llama_instance in llama_instance_list:
@@ -648,6 +654,7 @@ def permute(w, heads):
         for layer in llama_instance.layers:
             if getattr(layer.attention, "prepare_sha", None):
                 layer.attention.prepare_sha()
+
             if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
                 layer.feed_forward.prepare_feedfoward_conv()
 
@@ -1293,8 +1300,13 @@ def export_llama(args) -> None:
     runtime_tokenizer_path = tokenizer_artifacts[-1]
     tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)
 
+    if args.decoder_model == "codegen2_1b":
+        # Override the default BOS and EOS token IDs for codegen2_1b
+        tokenizer.bos_id = 1
+        tokenizer.eos_id = 2
+
     # TODO: Remove this once error is resolved.
-    if args.decoder_model == "phi_4_mini":
+    elif args.decoder_model == "phi_4_mini":
         with open(runtime_tokenizer_path, "r+") as file:
             data = json.load(file)
             # TODO: Encountered the following error during runtime, so switched behavior for now.