From cf62b4867a0e7cfd2979517b2526b113c2e83db6 Mon Sep 17 00:00:00 2001 From: thchenqti Date: Wed, 3 Sep 2025 16:51:52 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - codegen2-1B --- .../qualcomm/quantizer/custom_annotation.py | 3 + backends/qualcomm/tests/test_qnn_delegate.py | 61 ++++++++++++ examples/models/codegen/__init__.py | 16 ++++ examples/models/codegen/config/config.json | 19 ++++ examples/models/codegen/convert_weight.py | 93 +++++++++++++++++++ examples/models/llama/model_args.py | 5 + examples/qualcomm/oss_scripts/llama/README.md | 22 +++-- .../qualcomm/oss_scripts/llama/__init__.py | 25 +++++ .../oss_scripts/llama/decoder_constants.py | 1 + examples/qualcomm/oss_scripts/llama/llama.py | 28 ++++-- .../oss_scripts/llama/model/__init__.py | 16 ++++ .../oss_scripts/llama/model/apply_rope.py | 52 +++++++++++ .../oss_scripts/llama/model/feed_forward.py | 90 ++++++++++++++++++ .../oss_scripts/llama/model/layernorm.py | 48 ++++++++++ .../oss_scripts/llama/model/static_llama.py | 77 ++++++--------- .../oss_scripts/llama/qnn_llama_runner.cpp | 5 +- .../oss_scripts/llama/runner/runner.cpp | 4 + .../oss_scripts/llama/runner/runner.h | 3 +- 18 files changed, 500 insertions(+), 68 deletions(-) create mode 100644 examples/models/codegen/__init__.py create mode 100644 examples/models/codegen/config/config.json create mode 100644 examples/models/codegen/convert_weight.py create mode 100644 examples/qualcomm/oss_scripts/llama/model/__init__.py create mode 100644 examples/qualcomm/oss_scripts/llama/model/apply_rope.py create mode 100644 examples/qualcomm/oss_scripts/llama/model/feed_forward.py create mode 100644 examples/qualcomm/oss_scripts/llama/model/layernorm.py diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 3f10dbaa3fc..c592ad64da6 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -138,6 +138,9 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None weight = node.args[1] input_qspec_map[weight] = quantization_config.weight + if len(node.args) > 2 and isinstance(node.args[2], Node): + input_qspec_map[node.args[2]] = quantization_config.bias(node) + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=quantization_config.output_activation, diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 1c7e63f1bf4..97426cd9a7e 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5692,6 +5692,67 @@ def test_qnn_backend_seq_mse(self): class TestExampleLLMScript(TestQNN): + def test_codegen2_1b(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + prompt = "def hello_world():" + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + prompt, + "--temperature", + "0", + "--decoder_model", + "codegen2_1b", + "--model_mode", + "kv", + "--max_seq_len", + "128", + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + cmds.extend(["--host", self.host]) + elif self.enable_x86_64: + cmds.extend(["--enable_x86_64"]) + if self.pre_gen_pte: + 
cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + golden_start_with = "def hello_world():" + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: {golden_start_with}. Actual Output: {model_out}", + ) + if not self.enable_x86_64: + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, 1_200_000_000) # 1200MB + if not self.compile_only and not self.enable_x86_64: + self.assertGreaterEqual(msg["inference_speed"], 60) + def test_static_gemma_2b(self): if not self.required_envs(): self.skipTest("missing required envs") diff --git a/examples/models/codegen/__init__.py b/examples/models/codegen/__init__.py new file mode 100644 index 00000000000..359e3bd6243 --- /dev/null +++ b/examples/models/codegen/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.codegen.convert_weight import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class CodeGenModel(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "CodegenModel", + "convert_weights", +] diff --git a/examples/models/codegen/config/config.json b/examples/models/codegen/config/config.json new file mode 100644 index 00000000000..4a9b40a1154 --- /dev/null +++ b/examples/models/codegen/config/config.json @@ -0,0 +1,19 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 8192, + "n_heads": 16, + "n_kv_heads": 16, + "n_layers": 16, + "vocab_size": 51200, + "norm_eps": 1e-05, + "max_seq_len": 2048, + "bos_idx": 1, + "eos_idx": 2, + "model_architecture": "CodeGenModel", + "use_hf_rope": true, + "partial_rotary_factor": 0.5, + "use_ffn_norm" : false, + "norm_type": "layernorm", + "output_bias": true +} diff --git a/examples/models/codegen/convert_weight.py b/examples/models/codegen/convert_weight.py new file mode 100644 index 00000000000..5f98128038e --- /dev/null +++ b/examples/models/codegen/convert_weight.py @@ -0,0 +1,93 @@ +import argparse +import os +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. 
+_HF__CODEGEN_2_FROM_META = { + "tok_embeddings.weight": "transformer.wte.weight", + "layers.{}.attention_norm.weight": "transformer.h.{}.ln_1.weight", + "layers.{}.attention_norm.bias": "transformer.h.{}.ln_1.bias", + "layers.{}.attention.wq.weight": "transformer.h.{}.attn.q_proj.weight", + "layers.{}.attention.wk.weight": "transformer.h.{}.attn.k_proj.weight", + "layers.{}.attention.wv.weight": "transformer.h.{}.attn.v_proj.weight", + "layers.{}.attention.wo.weight": "transformer.h.{}.attn.out_proj.weight", + "layers.{}.feed_forward.fc_in.weight": "transformer.h.{}.mlp.fc_in.weight", + "layers.{}.feed_forward.fc_in.bias": "transformer.h.{}.mlp.fc_in.bias", + "layers.{}.feed_forward.fc_out.weight": "transformer.h.{}.mlp.fc_out.weight", + "layers.{}.feed_forward.fc_out.bias": "transformer.h.{}.mlp.fc_out.bias", + "norm.weight": "transformer.ln_f.weight", + "norm.bias": "transformer.ln_f.bias", + "output.weight": "lm_head.weight", + "output.bias": "lm_head.bias", +} + + +def codegen_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_state_dict = {} + keys_to_remove = [] + for key in state_dict: + if ".attn.causal_mask" in key: + keys_to_remove.append(key) + for key in keys_to_remove: + state_dict.pop(key) + inverted_mapping_dict = {v: k for k, v in _HF__CODEGEN_2_FROM_META.items()} + for key, value in state_dict.items(): + if key.endswith("attn.qkv_proj.weight"): + mp_num = 8 # This number is from modeling_codegen.py + dim, dim_kv = value.shape + block = dim // mp_num + split_size = block // 3 + + qkv_blocks = value.reshape(mp_num, block, dim_kv) + q_blocks = qkv_blocks[:, 0:split_size, :] + v_blocks = qkv_blocks[:, split_size : 2 * split_size, :] + k_blocks = qkv_blocks[:, 2 * split_size : 3 * split_size, :] + + q = q_blocks.reshape(-1, dim_kv) + v = v_blocks.reshape(-1, dim_kv) + k = k_blocks.reshape(-1, dim_kv) + + for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]: + new_key = key.replace("qkv_proj", new_key) + new_key = get_mapped_key(new_key, inverted_mapping_dict) + converted_state_dict[new_key] = new_value + else: + mapped_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[mapped_key] = value + + return converted_state_dict + + +def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None: + pt_path = os.path.join(input_dir_or_checkpoint, "pytorch_model.bin") + print("Loading checkpoint from file...") + sd = torch.load(pt_path, map_location="cpu") + print("Converting checkpoint...") + sd = codegen_hf_to_meta(sd) + + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Codegen weights to Meta format." 
+ ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files, or path to a single checkpoint file.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 20663c81e7d..3f82286b8ed 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -46,12 +46,17 @@ class ModelArgs: head_dim: Optional[int] = None # Optional customized head_dim multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None + model_architecture: str = ( + "LlamaForCausalLM" # This setting is currently only supported for the QNN backend + ) norm_eps: float = 1e-5 post_attention_norm: bool = False post_ffn_norm: bool = False max_batch_size: int = 1 max_seq_len: int = 2048 max_context_len: int = 2048 + use_ffn_norm: bool = True + output_bias: bool = False moe: bool = False # True to enable the MoE (Mixture of Experts) num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index be25324d63d..e6fa9a66e26 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -5,13 +5,14 @@ This file provides you the instructions to run LLM Decoder model with different 1. LLAMA2 Stories 110M 2. LLAMA3.2 1B 3. LLAMA3.2 3B - 4. Gemma 2B - 5. Gemma3 1B - 6. Phi4-mini-instruct - 7. QWEN2.5 0.5B / 1.5B - 8. QWEN3 0.6B / 1.7B - 9. SmolLM2 135M - 10. SmolLM3 3B + 4. Codegen2 1B + 5. Gemma 2B + 6. Gemma3 1B + 7. Phi4-mini-instruct + 8. QWEN2.5 0.5B / 1.5B + 9. QWEN3 0.6B / 1.7B + 10. SmolLM2 135M + 11. SmolLM3 3B We offer the following modes to execute the model: @@ -80,6 +81,12 @@ Default example using kv mode. python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` +#### Codegen2 +Default example using kv mode. +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model codegen2_1b --model_mode kv --max_seq_len 1024 --prompt "def hello_world():" +``` + #### Gemma 2B Default example using hybrid mode ```bash @@ -135,7 +142,6 @@ Default example using kv mode. python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 ``` - ### KV Cache update mechanism We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask. 
diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index 628defc1496..e2407e6812a 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -23,6 +23,9 @@ get_ptq_per_channel_quant_config, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.examples.models.codegen import ( + convert_weights as convert_codegen_weights, +) from executorch.examples.models.gemma import convert_weights as convert_gemma_weights from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights @@ -331,6 +334,28 @@ class Gemma_2B(LLMModelConfig): ) +@register_llm_model("codegen2_1b") +@dataclass(init=False, frozen=True) +class Codegen(LLMModelConfig): + repo_id: str = "Salesforce/codegen2-1B_P" + params_path: str = os.path.join( + BASE_DIR, "../../../models/codegen/config/config.json" + ) + convert_weights = convert_codegen_weights + transform_weight = True + instruct_model = False + num_sharding = 1 + # quant config + ptq = QuantDtype.use_16a8w + group_size = None + masked_softmax = True + seq_mse_candidates = 0 + r1 = False + r2 = False + r3 = False + custom_annotation = () + + @register_llm_model("gemma3-1b") @dataclass(init=False, frozen=True) class Gemma3(LLMModelConfig): diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index d43ceb8351a..c7e7c0cb944 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -25,4 +25,5 @@ "qwen3-1_7b": "qwen3", "smollm2_135m": "smollm2_135m", "smollm3-3b": "smollm3", + "codegen2_1b": "codegen", } diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 887e680341f..73fe45f1c60 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -445,7 +445,6 @@ def compile( kv_config.use_kv_cache = True kv_config.enable_r3 = decoder_model_config.r3 kv_config.kv_io_bit_width = decoder_model_config.get_kv_io_bit_width() - if decoder_model_config.masked_softmax: if is_qnn_sdk_version_less_than("2.35"): logging.warning( @@ -561,25 +560,30 @@ def compile( if decoder_model_config.transform_weight: # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. - def permute(w, heads): + def permute(w, heads, partial_rotary_dim): dim_0 = w.size(0) dim_1 = w.size(1) - return ( - w.view(heads, dim_0 // heads // 2, 2, dim_1) - .transpose(1, 2) + transformed_weight = ( + w.view(heads, -1, dim_0 // heads // 2 // partial_rotary_dim, 2, dim_1) + .transpose(2, 3) .reshape(dim_0, dim_1) ) + return transformed_weight n_heads = llama_instance_list[0].n_heads n_kv_heads = llama_instance_list[0].n_kv_heads n_layers = llama_instance_list[0].n_layers - + partial_rotary_dim = int(1 // kv_config.partial_rotary_factor) # TODO Handle cases where input size isn't divisible. 
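+        # wq and wk receive the same per-head reordering so the HF-style (half-split)
+        # RoPE kernels can be used on the HTP backend; partial_rotary_dim keeps the
+        # reordering aligned for models such as codegen2_1b that rotate only
+        # partial_rotary_factor of each head dimension.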
for layer_i in range(n_layers): state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + state_dict[f"layers.{layer_i}.attention.wq.weight"], + n_heads, + partial_rotary_dim, ) state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + state_dict[f"layers.{layer_i}.attention.wk.weight"], + n_kv_heads, + partial_rotary_dim, ) for llama_instance in llama_instance_list: @@ -648,6 +652,7 @@ def permute(w, heads): for layer in llama_instance.layers: if getattr(layer.attention, "prepare_sha", None): layer.attention.prepare_sha() + if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): layer.feed_forward.prepare_feedfoward_conv() @@ -1299,8 +1304,13 @@ def export_llama(args) -> None: runtime_tokenizer_path = tokenizer_artifacts[-1] tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config) + if args.decoder_model == "codegen2_1b": + # Override the default BOS and EOS token IDs for codegen2_1b + tokenizer.bos_id = 1 + tokenizer.eos_id = 2 + # TODO: Remove this once error is resolved. - if args.decoder_model == "phi_4_mini": + elif args.decoder_model == "phi_4_mini": with open(runtime_tokenizer_path, "r+") as file: data = json.load(file) # TODO: Encountered the following error during runtime, so switched behavior for now. diff --git a/examples/qualcomm/oss_scripts/llama/model/__init__.py b/examples/qualcomm/oss_scripts/llama/model/__init__.py new file mode 100644 index 00000000000..ea94730615c --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/model/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .apply_rope import ROTARY_EMB_REGISTRY +from .feed_forward import FeedForward_REGISTRY +from .layernorm import NORM_REGISTRY + + +__all__ = [ + FeedForward_REGISTRY, + ROTARY_EMB_REGISTRY, + NORM_REGISTRY, +] diff --git a/examples/qualcomm/oss_scripts/llama/model/apply_rope.py b/examples/qualcomm/oss_scripts/llama/model/apply_rope.py new file mode 100644 index 00000000000..6d011c47336 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/model/apply_rope.py @@ -0,0 +1,52 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
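+
+# Registry of rotary-embedding variants consumed by static_llama.py:
+#   "default" applies HF-style half-split RoPE across the full head dimension;
+#   "partial" rotates only the leading rotary_dim slice (e.g. codegen2_1b with
+#   partial_rotary_factor 0.5) and leaves the remaining dims untouched.
+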
+from typing import Callable, Dict + +import torch + + +ROTARY_EMB_REGISTRY: Dict[str, Callable] = {} + + +def register_rotary_emb(name: str): + """Register a rotary embedding function.""" + + def decorator(fn: Callable): + ROTARY_EMB_REGISTRY[name] = fn + return fn + + return decorator + + +@register_rotary_emb("partial") +def apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin): + if x.dim() == 4: + freqs_cos = freqs_cos[None, :, None, :] + freqs_sin = freqs_sin[None, :, None, :] + rotary_dim = freqs_cos.shape[-1] * 2 + x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] + x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :] + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + x_rotated = torch.cat([x_out_r, x_out_i], dim=-1) + return torch.cat([x_rotated, x_pass], dim=-1) + + +@register_rotary_emb("default") +def apply_rotary_emb_single(x, freqs_cos, freqs_sin): + # The implementation of RoPE in HuggingFace processes query and key with two half instead of interleaved way. + # The main difference is stride in StrideSlice op. For interleaved way, stride is two which is not friendly for HTP backend. + # Ref: https://github.com/huggingface/transformers/issues/25199 + x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + # broadcast for batch_prefill mode input x + if x.dim() == 4: + freqs_cos = freqs_cos[None, :, None, :] + freqs_sin = freqs_sin[None, :, None, :] + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out diff --git a/examples/qualcomm/oss_scripts/llama/model/feed_forward.py b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py new file mode 100644 index 00000000000..062123b52cc --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/model/feed_forward.py @@ -0,0 +1,90 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from abc import ABC, abstractmethod +from typing import Dict, Type + +import torch +from executorch.examples.models.llama.model_args import ModelArgs +from transformers.activations import GELUActivation + + +class FeedForwardBase(torch.nn.Module, ABC): + """Abstract base class for feed forward layers with unified interface.""" + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for feed forward layer. 
+ Args: + x: Input tensor + Returns: + Output tensor + """ + pass + + +FeedForward_REGISTRY: Dict[str, Type[FeedForwardBase]] = {} + + +def register_feed_forward(name: str): + """Decorator to register norm classes""" + + def decorator(cls: Type[FeedForwardBase]): + FeedForward_REGISTRY[name] = cls + return cls + + return decorator + + +@register_feed_forward("CodeGenModel") +class CodegenFeedForward(FeedForwardBase): + """FeedForward with fc_in and fc_out""" + + def __init__(self, args: ModelArgs): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + + assert args.hidden_dim is not None + self.dim = args.dim + self.hidden_dim: int = args.hidden_dim + + self.fc_in = torch.nn.Linear(self.dim, self.hidden_dim) + self.fc_out = torch.nn.Linear(self.hidden_dim, self.dim) + # HF uses NewGelu, however, Gelu is a fused op in QNN and can run faster + self.act = GELUActivation(use_gelu_python=False) + + def prepare_feedfoward_conv(self): + intermediate_size = 4 * self.dim + self.fc_in_conv = torch.nn.Conv2d(self.dim, intermediate_size, 1, bias=True) + self.fc_out_conv = torch.nn.Conv2d(self.hidden_dim, self.dim, 1, bias=True) + + self.forward_no_conv = self.forward + self.forward = self.forward_feedfoward_conv + + self.fc_in_conv.weight.data.copy_(self.fc_in.weight[:, :, None, None]) + self.fc_out_conv.weight.data.copy_(self.fc_out.weight[:, :, None, None]) + + self.fc_in_conv.bias.data.copy_(self.fc_in.bias) + self.fc_out_conv.bias.data.copy_(self.fc_out.bias) + + del self.fc_in + del self.fc_out + + def forward_feedfoward_conv(self, x): + bsz, _, _ = x.size() + + x = torch.reshape(x, (bsz, -1, 1, self.dim)) + x = x.transpose(1, 3) # Transpose right before and after Conv + x = self.fc_in_conv(x) + x = self.act(x) + x = self.fc_out_conv(x) + x = x.transpose(1, 3) + x = torch.reshape(x, (bsz, -1, self.dim)) + return x + + def forward(self, x): + hidden_states = self.fc_in(x) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + return hidden_states diff --git a/examples/qualcomm/oss_scripts/llama/model/layernorm.py b/examples/qualcomm/oss_scripts/llama/model/layernorm.py new file mode 100644 index 00000000000..a6c12920ed8 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/model/layernorm.py @@ -0,0 +1,48 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from abc import ABC, abstractmethod +from typing import Dict, Type + +import torch + + +class Norm(torch.nn.Module, ABC): + """Abstract base class for normalization layers with unified interface.""" + + @abstractmethod + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for normalization layer. 
+ Args: + x: Input tensor + Returns: + Normalized tensor + """ + pass + + +NORM_REGISTRY: Dict[str, Type[Norm]] = {} + + +def register_norm(name: str): + """Decorator to register norm classes""" + + def decorator(cls: Type[Norm]): + NORM_REGISTRY[name] = cls + return cls + + return decorator + + +@register_norm("layernorm") +class LayerNorm(torch.nn.LayerNorm, Norm): + def __init__(self, hidden_size: int, eps=1e-5): + super().__init__(hidden_size, eps=eps) + + +@register_norm("rmsnorm") +class RMSNorm(torch.nn.RMSNorm, Norm): + def __init__(self, hidden_size: int, eps=1e-5): + super().__init__(hidden_size, eps=eps) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 7406b13ee8c..ba2d33d7890 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -13,6 +13,7 @@ import scipy import torch import torch.nn as nn + from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( hf_precompute_freqs_cis, @@ -23,42 +24,11 @@ CausalAttentionMask, SlidingWindowAttentionMask, ) - - -def apply_rotary_emb_single( - x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor -) -> torch.Tensor: - # The implementation of RoPE in HuggingFace processes query and key with two half instead of interleaved way. - # The main difference is stride in StrideSlice op. For interleaved way, stride is two which is not friendly for HTP backend. - # Ref: https://github.com/huggingface/transformers/issues/25199 - x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - # broadcast for batch_prefill mode input x - if x.dim() == 4: - freqs_cos = freqs_cos[None, :, None, :] - freqs_sin = freqs_sin[None, :, None, :] - x_out_r = x_r * freqs_cos - x_i * freqs_sin - x_out_i = x_r * freqs_sin + x_i * freqs_cos - - x_out = torch.cat([x_out_r, x_out_i], dim=-1) - return x_out - - -def apply_partial_rotary_emb_single( - x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor -) -> torch.Tensor: - - if x.dim() == 4: - freqs_cos = freqs_cos[None, :, None, :] - freqs_sin = freqs_sin[None, :, None, :] - - rotary_dim = freqs_cos.shape[-1] * 2 - - x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] - x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :] - x_out_r = x_r * freqs_cos - x_i * freqs_sin - x_out_i = x_r * freqs_sin + x_i * freqs_cos - x_rotated = torch.cat([x_out_r, x_out_i], dim=-1) - return torch.cat([x_rotated, x_pass], dim=-1) +from executorch.examples.qualcomm.oss_scripts.llama.model import ( + FeedForward_REGISTRY, + NORM_REGISTRY, + ROTARY_EMB_REGISTRY, +) class LlamaAttention(nn.Module): @@ -88,9 +58,9 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps) if config.partial_rotary_factor < 1: - self.apply_rope_emb = apply_partial_rotary_emb_single + self.apply_rope_emb = ROTARY_EMB_REGISTRY["partial"] else: - self.apply_rope_emb = apply_rotary_emb_single + self.apply_rope_emb = ROTARY_EMB_REGISTRY["default"] self.wq = nn.Linear( self.dim, @@ -227,7 +197,6 @@ def forward_sha( # noqa: C901 .reshape(bsz, seq_len, self.head_dim) for wv_sha in self.wv_sha ] - for i in range(len(q)): if self.use_qk_norm and self.qk_norm_before_rope: q[i] = self.q_norm_fn(q[i]) @@ -411,9 +380,19 @@ def __init__(self, layer_idx: int, config: ModelArgs, 
output_new_cache_only=Fals config=config, output_new_cache_only=output_new_cache_only, ) - self.feed_forward = FeedForward(config) - self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) - self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) + + self.feed_forward = FeedForward_REGISTRY.get( + config.model_architecture, FeedForward + )(config) + self.attention_norm = NORM_REGISTRY[config.norm_type]( + config.dim, eps=config.norm_eps + ) + self.ffn_norm = ( + NORM_REGISTRY[config.norm_type](config.dim, eps=config.norm_eps) + if config.use_ffn_norm + else None + ) + self.post_attention_norm = ( torch.nn.RMSNorm(config.dim, eps=config.norm_eps) if config.post_attention_norm @@ -434,8 +413,10 @@ def forward( k_caches: List[torch.Tensor], v_caches: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + hidden_states = self.attention_norm(x) h, k_cache, v_cache = self.attention( - hidden_states=self.attention_norm(x), + hidden_states=hidden_states, freqs_cos=freqs_cos, freqs_sin=freqs_sin, atten_mask=atten_mask, @@ -445,10 +426,12 @@ def forward( if self.post_attention_norm: h = self.post_attention_norm(h) h = x + h - out = self.feed_forward(self.ffn_norm(h)) + hidden_states = hidden_states if self.ffn_norm is None else self.ffn_norm(h) + out = self.feed_forward(hidden_states) if self.post_ffn_norm: out = self.post_ffn_norm(out) output = h + out + return output, k_cache, v_cache @@ -486,8 +469,9 @@ def __init__( for i in range(config.n_layers) ] ) - self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) - self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + self.norm = NORM_REGISTRY[config.norm_type](config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=config.output_bias) + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) if config.use_hf_rope: freqs_cos, freqs_sin = hf_precompute_freqs_cis( @@ -532,7 +516,6 @@ def forward( input_pos: Optional[torch.Tensor] = None, *args, ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: - output_k_cache = [] output_v_cache = [] # following tensors should be invariant across batches diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 2bffb35852a..52796e886fd 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -103,6 +103,8 @@ std::string get_formatted_prompt( std::string formatted_prompt; switch (decoder_model_version) { case example::DecoderModelVersion::kLlama2: + case example::DecoderModelVersion::kQwen2_5: + case example::DecoderModelVersion::kCodegen: formatted_prompt.append(prompt); break; case example::DecoderModelVersion::kLlama3: @@ -138,9 +140,6 @@ std::string get_formatted_prompt( formatted_prompt.append(prompt); formatted_prompt.append("<|end|><|assistant|>"); break; - case example::DecoderModelVersion::kQwen2_5: - formatted_prompt.append(prompt); - break; case example::DecoderModelVersion::kQwen3: formatted_prompt.append("<|im_start|>user\n"); formatted_prompt.append(prompt); diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 0c4884bbccf..e239a2a5fe1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -137,6 +137,8 @@ Runner::Runner( decoder_model_version_ = 
DecoderModelVersion::kSmollm2_135m;
   } else if (decoder_model_version == "smollm3") {
     decoder_model_version_ = DecoderModelVersion::kSmollm3;
+  } else if (decoder_model_version == "codegen") {
+    decoder_model_version_ = DecoderModelVersion::kCodegen;
   } else {
     ET_CHECK_MSG(false, "Unsupported Decoder Model");
   }
@@ -205,6 +207,8 @@ Error Runner::load() {
       decoder_model_version_ == DecoderModelVersion::kGemma ||
       decoder_model_version_ == DecoderModelVersion::kGemma3) {
     eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
+  } else if (decoder_model_version_ == DecoderModelVersion::kCodegen) {
+    eos_ids->insert(tokenizer_->encode("<|endoftext|>", 0, 0).get()[0]);
   }
 
   // Try avoid getMetadataHelper as it is time consuming.
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 1472093ab66..9cf730c3620 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -38,7 +38,8 @@ enum DecoderModelVersion {
   kQwen2_5,
   kQwen3,
   kSmollm2_135m,
-  kSmollm3
+  kSmollm3,
+  kCodegen,
 };
 
 enum KvBitWidth {
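
A minimal standalone sketch of the fused qkv_proj split performed by codegen_hf_to_meta above, assuming codegen2-1B shapes (embed dim 2048) and the mp_num = 8 block layout from HF's modeling_codegen.py; the tensor names are illustrative:

```python
import torch

mp_num = 8                        # block count used by HF modeling_codegen.py
dim, dim_kv = 3 * 2048, 2048      # qkv_proj.weight is (3 * embed_dim, embed_dim)
qkv = torch.randn(dim, dim_kv)    # stand-in for the fused qkv_proj.weight tensor

block = dim // mp_num             # rows per block: one q, one v and one k slice
split = block // 3                # rows of a single projection inside a block

blocks = qkv.reshape(mp_num, block, dim_kv)
q = blocks[:, :split, :].reshape(-1, dim_kv)              # query rows from every block
v = blocks[:, split : 2 * split, :].reshape(-1, dim_kv)   # value rows from every block
k = blocks[:, 2 * split :, :].reshape(-1, dim_kv)         # key rows from every block

assert q.shape == v.shape == k.shape == (dim_kv, dim_kv)
```

Within each of the mp_num row blocks the stored order is query, value, key, which is why the slices are taken in that order before being concatenated back into full (2048, 2048) projection matrices.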