diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 9035a7a458c..7e492fa3e30 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4638,6 +4638,77 @@ def test_qnn_backend_generate_optrace(self): class TestExampleLLMScript(TestQNN): + def test_static_gemma3_1b(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + prompt = "My favourite condiment is " + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + "--prompt", + f"{prompt}", + "--ptq", + "16a4w_block", + "--temperature", + "0", + "--decoder_model", + "gemma3-1b", + "--model_mode", + "kv", + "--max_seq_len", + "1024", + "--eval_perplexity", + "--tasks", + "wikitext", + "--limit", + "1", + "--enable_masked_softmax", + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + cmds.extend(["--host", self.host]) + elif self.enable_x86_64: + cmds.extend(["--enable_x86_64"]) + if self.pre_gen_pte: + cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + if not self.compile_only: + self.assertLessEqual(msg["wiki_ppl"], 23) + if not self.enable_x86_64: + pte_size = msg["pte_size"] + self.assertLessEqual(pte_size, 1_200_000_000) # 1.2GB + inference_speed_ref = {"SM8650": 70, "SM8750": 100} + if ( + not self.compile_only + and not self.enable_x86_64 + and self.model in inference_speed_ref + ): + self.assertGreaterEqual( + msg["inference_speed"], inference_speed_ref[self.model] + ) + def test_llama3_2_1b(self): if not self.required_envs(): self.skipTest("missing required envs") @@ -4708,7 +4779,7 @@ def test_llama3_2_1b(self): # Inference speed on x86 is slow, so we only check when running on Android if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 1300000000) + self.assertLessEqual(pte_size, 1_300_000_000) # 1.3GB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 66) # Lanai @@ -4784,7 +4855,7 @@ def test_llama_stories_260k(self): # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 2020000) + self.assertLessEqual(pte_size, 2_020_000) # 2MB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 1600) # Lanai @@ -4859,7 +4930,7 @@ def test_llama_stories_110m(self): # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 130000000) + self.assertLessEqual(pte_size, 130_000_000) # 130MB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai @@ -4922,7 +4993,7 @@ def test_static_phi4(self): else: inference_speed_ref = {"SM8650": 14, "SM8750": 19} self.assertLessEqual(msg["wiki_ppl"], 12) - self.assertLessEqual(msg["pte_size"], 4000000000) # 4gb + self.assertLessEqual(msg["pte_size"], 4_000_000_000) # 4GB if self.model 
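The new `test_static_gemma3_1b` follows the same harness pattern as the other LLM tests: spawn the example script as a subprocess and receive a single JSON result over a `multiprocessing.connection.Listener`. A minimal sketch of that handshake, with a hypothetical `worker()` standing in for the example script's sender side:

```python
# Minimal sketch of the test's result handshake. worker() is a hypothetical
# stand-in; in the real tests the llama.py example script connects back and
# reports metrics such as wiki_ppl, pte_size and inference_speed as JSON.
import json
import subprocess
import sys
from multiprocessing.connection import Client, Listener

ADDRESS = ("127.0.0.1", 16789)  # hypothetical ip/port


def worker():
    with Client(ADDRESS) as conn:
        conn.send(json.dumps({"wiki_ppl": 20.5, "pte_size": 1_000_000_000}))


def main():
    proc = subprocess.Popen(
        [sys.executable, __file__, "--worker"], stdout=subprocess.DEVNULL
    )
    with Listener(ADDRESS) as listener:
        conn = listener.accept()  # blocks until the subprocess connects back
        proc.communicate()        # wait for the subprocess to finish
        msg = json.loads(conn.recv())
    assert msg["wiki_ppl"] <= 23  # same style of threshold check as the test


if __name__ == "__main__":
    worker() if "--worker" in sys.argv else main()
```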
in inference_speed_ref: self.assertGreaterEqual( msg["inference_speed"], inference_speed_ref[self.model] @@ -4981,7 +5052,7 @@ def test_static_qwen2_5(self): else: inference_speed_ref = {"SM8650": 115, "SM8750": 155} self.assertLessEqual(msg["wiki_ppl"], 15) - self.assertLessEqual(msg["pte_size"], 600000000) # 600mb + self.assertLessEqual(msg["pte_size"], 600_000_000) # 600MB if self.model in inference_speed_ref: self.assertGreaterEqual( msg["inference_speed"], inference_speed_ref[self.model] @@ -5040,7 +5111,7 @@ def test_static_qwen3(self): else: inference_speed_ref = {"SM8650": 38, "SM8750": 56} self.assertLessEqual(msg["wiki_ppl"], 18) - self.assertLessEqual(msg["pte_size"], 950_000_000) # 950mb + self.assertLessEqual(msg["pte_size"], 950_000_000) # 950MB if self.model in inference_speed_ref: self.assertGreaterEqual( msg["inference_speed"], inference_speed_ref[self.model] diff --git a/examples/models/gemma3/__init__.py b/examples/models/gemma3/__init__.py new file mode 100644 index 00000000000..ae34db47954 --- /dev/null +++ b/examples/models/gemma3/__init__.py @@ -0,0 +1,16 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.gemma3.convert_weights import convert_weights +from executorch.examples.models.llama.model import Llama2Model + + +class Gemma3Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "Gemma3Model", + "convert_weights", +] diff --git a/examples/models/gemma3/config/1b_config.json b/examples/models/gemma3/config/1b_config.json new file mode 100644 index 00000000000..3a9e673716b --- /dev/null +++ b/examples/models/gemma3/config/1b_config.json @@ -0,0 +1,23 @@ +{ + "dim": 1152, + "ffn_dim_multiplier": 1, + "hidden_dim": 6912, + "n_heads": 4, + "head_dim": 256, + "n_kv_heads": 1, + "n_layers": 26, + "act_fn": "gelu_approx", + "norm_type": "gemma3", + "norm_eps": 1e-06, + "post_attention_norm": true, + "post_ffn_norm": true, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "apply_embedding": true, + "embedding_scale_factor": 33.941125497, + "vocab_size": 262144, + "use_hf_rope": true, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true +} diff --git a/examples/models/gemma3/convert_weights.py b/examples/models/gemma3/convert_weights.py new file mode 100644 index 00000000000..ed44b9eb1cc --- /dev/null +++ b/examples/models/gemma3/convert_weights.py @@ -0,0 +1,110 @@ +import argparse + +import json +import os +from typing import Dict + +import torch +from safetensors.torch import load_file + +from torchtune.models.convert_weights import get_mapped_key + + +# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters. 
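One detail of the new `1b_config.json`: the `embedding_scale_factor` of 33.941125497 is `sqrt(dim) = sqrt(1152)` to float precision, matching Gemma's convention of scaling token embeddings by the square root of the hidden size (the sqrt(dim) reading is an inference from that convention, not something the config states):

```python
# The embedding_scale_factor in 1b_config.json equals sqrt(1152), i.e. sqrt(dim);
# assumed to follow Gemma's sqrt(hidden_size) embedding scaling.
import math

assert math.isclose(33.941125497, math.sqrt(1152), rel_tol=1e-9)
```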
+_GEMMA3_TO_EXECUTORCH = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", +} + + +def gemma3_to_executorch( + state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Convert the state dict so that it matches what ExecuTorch's transformer definition expects. + """ + converted_state_dict = {} + for key, value in state_dict.items(): + new_key = get_mapped_key(key, _GEMMA3_TO_EXECUTORCH) + converted_state_dict[new_key] = value + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + return converted_state_dict + + +def load_checkpoint_from_safetensors(input_dir: str) -> Dict: + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Sharded checkpoint. + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + # Load all the shards into memory + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + # Merge tensors into consolidated state dict. + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + tensor = shard_to_weights[shard][weight_name] + merged_state_dict[weight_name] = tensor + return merged_state_dict + else: + # Single checkpoint. + state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + pytorch_path = os.path.join(input_dir, "pytorch_model.bin") + if os.path.exists(pytorch_path): + print("Loading checkpoint from PyTorch .bin file") + return torch.load(pytorch_path, map_location="cpu", weights_only=True) + print("Loading checkpoint from safetensors directory") + return load_checkpoint_from_safetensors(input_dir) + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + sd = load_checkpoint(input_dir) + print("Converting checkpoint...") + sd = gemma3_to_executorch(sd) + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Gemma3 weights to ExecuTorch transformer format." 
+ ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 3ed9f23443b..651047ecd96 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -1,7 +1,39 @@ import dataclasses from dataclasses import dataclass +from enum import Enum +from functools import partial from typing import Any, Dict, Optional +import torch.nn.functional as F + + +class ActFn(Enum): + SILU = "silu" + GELU = "gelu" + GELU_APPROX = "gelu_approx" + + @classmethod + def from_string(cls, value: str) -> "ActFn": + """Convert string to ActFn enum.""" + try: + return cls(value) + except ValueError: + valid_values = [e.value for e in cls] + raise ValueError( + f"Invalid activation function: {value}. Valid options: {valid_values}" + ) + + def get_function(self): + """Return the corresponding activation function.""" + if self == ActFn.SILU: + return F.silu + elif self == ActFn.GELU: + return F.gelu + elif self == ActFn.GELU_APPROX: + return partial(F.gelu, approximate="tanh") + else: + raise ValueError(f"Unsupported activation function: {self}") + @dataclass class ModelArgs: @@ -15,6 +47,8 @@ class ModelArgs: multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 + post_attention_norm: bool = False + post_ffn_norm: bool = False max_batch_size: int = 1 max_seq_len: int = 2048 max_context_len: int = 2048 @@ -22,6 +56,8 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + norm_type: str = "rmsnorm" # Normalization type, registered in norm.py + act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False use_kv_cache: bool = False # Use key/value cache use_sdpa_with_kv_cache_op: bool = ( @@ -37,6 +73,7 @@ class ModelArgs: # A dictionary mapping from pruned token-id to original token-id output_prune_map: Optional[Dict[int, int]] = None apply_embedding: bool = True # Use embedding inside the transformer + embedding_scale_factor: float = 1.0 # Multiple by which to scale embeddings. 
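A short usage sketch for the `ActFn` enum added to `model_args.py`; `"gelu_approx"` resolves to the tanh-approximated GELU that the Gemma3 config selects (import path inferred from the file location):

```python
# ActFn.from_string maps config strings to activation callables;
# "gelu_approx" is partial(F.gelu, approximate="tanh").
import torch
import torch.nn.functional as F

from executorch.examples.models.llama.model_args import ActFn  # inferred path

x = torch.randn(4)
act = ActFn.from_string("gelu_approx").get_function()
assert torch.equal(act(x), F.gelu(x, approximate="tanh"))

# Unknown names raise a ValueError that lists the valid options.
try:
    ActFn.from_string("relu")
except ValueError as err:
    print(err)
```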
apply_output: bool = True # Use output layer (unembedding) inside the transformer use_qk_norm: bool = False # apply normalization to q and k in the attention qk_norm_before_rope: bool = False # when to apply qk norm @@ -103,3 +140,7 @@ def find_multiple(n: int, k: int) -> int: if self.head_dim is None: self.head_dim = self.dim // self.n_heads + + # Convert string act_fn to enum if needed + if isinstance(self.act_fn, str): + self.act_fn = ActFn.from_string(self.act_fn) diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 78a7e2905e6..612f898028c 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -28,6 +28,7 @@ list( ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h + ${CMAKE_CURRENT_LIST_DIR}/runner/cache_utils.h ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.h ${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.cpp diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 40d30c77ffb..5ce15cabaa9 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -5,10 +5,12 @@ This file provides you the instructions to run LLM Decoder model with different 1. LLAMA2 Stories 110M 2. LLAMA3.2 1B 3. LLAMA3.2 3B - 4. QWEN2.5 0.5B / 1.5B - 5. QWEN3 0.6B / 1.7B - 6. Phi4-mini-instruct - 7. SMOLLM2 135M + 4. Gemma3 1B + 5. Phi4-mini-instruct + 6. QWEN2.5 0.5B / 1.5B + 7. QWEN3 0.6B / 1.7B + 8. SMOLLM2 135M + We offer the following modes to execute the model: @@ -69,10 +71,16 @@ Default example using hybrid mode. python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" ``` +#### Gemma3 1B +Default example using hybrid mode +```bash +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma3-1b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +``` + #### Phi4-mini-instruct Default example using kv mode. ```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 +python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model phi_4_mini --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" 
--tasks wikitext --limit 1 ``` #### QWEN2.5 0.5B diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index f85f48f3de4..ad74754708c 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -7,6 +7,7 @@ import os from abc import ABC from dataclasses import dataclass +from enum import Enum from functools import partial from typing import Callable, Dict, Tuple, Type @@ -20,6 +21,8 @@ get_ptq_per_channel_quant_config, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights from executorch.examples.models.phi_4_mini import ( convert_weights as convert_phi_4_mini_weights, ) @@ -34,11 +37,21 @@ from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import ( DECODER_MODEL_VERSION, ) +from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( + MultiScopeAwareLlamaModel, +) +from tabulate import tabulate from torchao.quantization.pt2e import MinMaxObserver + BASE_DIR = os.path.dirname(__file__) +LLM_VARIANT_ARCHS = { + "gemma3-1b": MultiScopeAwareLlamaModel, +} + + @dataclass(init=False, frozen=True) class LLMModelConfig(ABC): """ @@ -79,6 +92,50 @@ class LLMModelConfig(ABC): r3: bool custom_annotation: Tuple + def __str__(self): # noqa: C901 + """ + Visualize the current LLMModelConfig settings in a readable table format. + + This method helps users quickly inspect key configuration, + skipping internal or irrelevant attributes and formatting complex types + like functions, enums, and partials for clarity. + + Returns: + str: A table showing the current config for LLM models. + """ + + def format_value(v): + if isinstance(v, partial): + func_name = ( + v.func.__name__ if hasattr(v.func, "__name__") else str(v.func) + ) + return f"partial({func_name})" + elif isinstance(v, Callable): + return v.__name__ if hasattr(v, "__name__") else str(v) + elif isinstance(v, Enum): + return f"{v.__class__.__name__}.{v.name}" + elif isinstance(v, (tuple, list)): + return "(" + ", ".join(format_value(i) for i in v) + ")" + elif isinstance(v, (str, int, float, bool)): + return v + else: + return f"<{v.__class__.__name__}>" + + attrs = {} + for k in dir(self): + if k.startswith("_") or k in {"convert_weights", "params_path"}: + continue + try: + v = getattr(self, k) + if k in {"get_kv_io_bit_width", "get_logits_output_bit_width"}: + v = v() + except Exception: + v = f"Warning: failed to retrieve config for '{k}'" + if isinstance(v, (str, int, float, bool, tuple, list, Callable)): + attrs[k] = format_value(v) + table = [(k, v) for k, v in attrs.items()] + return tabulate(table, headers=["Config", "Value"], tablefmt="grid") + def get_kv_io_bit_width(self) -> int: if self.ptq is None: return 32 @@ -198,6 +255,66 @@ class Llama3_2(LLMModelConfig): ) +@register_llm_model("gemma3-1b") +@dataclass(init=False, frozen=True) +class Gemma3(LLMModelConfig): + repo_id: str = "google/gemma-3-1b-it" + params_path: str = os.path.join( + BASE_DIR, "../../../models/gemma3/config/1b_config.json" + ) + convert_weights = convert_gemma3_weights + transform_weight = False + instruct_model = True + + num_sharding = 1 + # quant config + ptq = QuantDtype.use_16a4w_block + group_size = 64 + masked_softmax = True + r1 = False + r2 = False + r3 = False + quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config( + torch.uint16, weight_dtype=torch.int8, 
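The `__str__` added to `LLMModelConfig` above renders the effective settings as a two-column grid via `tabulate`; a standalone illustration of the output format it produces (the rows here are made-up examples):

```python
# Standalone illustration of the tabulate "grid" layout produced by
# LLMModelConfig.__str__ (values are made-up examples).
from tabulate import tabulate

rows = [
    ("ptq", "QuantDtype.use_16a4w_block"),
    ("group_size", 64),
    ("masked_softmax", True),
]
print(tabulate(rows, headers=["Config", "Value"], tablefmt="grid"))
```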
act_observer=MinMaxObserver + ) + custom_annotation = ( + annotate_kv_8bit, + partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w), + ) + + +@register_llm_model("phi_4_mini") +@dataclass(init=False, frozen=True) +class Phi4Mini(LLMModelConfig): + repo_id: str = "microsoft/Phi-4-mini-instruct" + params_path: str = os.path.join( + BASE_DIR, "../../../models/phi_4_mini/config/config.json" + ) + convert_weights = convert_phi_4_mini_weights + transform_weight = False + instruct_model = True + + num_sharding = 8 + # quant config + ptq = QuantDtype.use_16a4w_block + group_size = 16 + masked_softmax = False + r1 = False + r2 = False + r3 = False + quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=torch.int4, + act_observer=MinMaxObserver, + act_symmetric=True, + ) + custom_annotation = ( + annotate_kv_8bit, + annotate_output_16a8w, + partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), + ) + + @register_llm_model("qwen2_5-0_5b") @dataclass(init=False, frozen=True) class Qwen2_5_0_5B(LLMModelConfig): @@ -289,38 +406,6 @@ class Qwen3_1_7B(LLMModelConfig): ) -@register_llm_model("phi_4_mini") -@dataclass(init=False, frozen=True) -class Phi4Mini(LLMModelConfig): - repo_id: str = "microsoft/Phi-4-mini-instruct" - params_path: str = os.path.join( - BASE_DIR, "../../../models/phi_4_mini/config/config.json" - ) - convert_weights = convert_phi_4_mini_weights - transform_weight = False - instruct_model = True - - num_sharding = 8 - # quant config - ptq = QuantDtype.use_16a4w_block - group_size = 16 - masked_softmax = False - r1 = False - r2 = False - r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) - custom_annotation = ( - annotate_kv_8bit, - annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), - ) - - @register_llm_model("smollm2_135m") @dataclass(init=False, frozen=True) class Smollm2_135M(LLMModelConfig): diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py index 9b00e38f73e..03bf5043d60 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py @@ -15,10 +15,11 @@ "stories260k": "llama2", "stories110m": "llama2", "llama3_2": "llama3", + "gemma3-1b": "gemma3", + "phi_4_mini": "phi_4_mini", "qwen2_5-0_5b": "qwen2_5", "qwen2_5-1_5b": "qwen2_5", "qwen3-0_6b": "qwen3", "qwen3-1_7b": "qwen3", - "phi_4_mini": "phi_4_mini", "smollm2_135m": "smollm2_135m", } diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 85749232f94..eaa25698e90 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -13,11 +13,12 @@ import torch from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper - from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import ( DECODER_MODEL_VERSION, EVAL_MODE, ) +from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import AttentionMask + from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB from executorch.exir._serialize._program import deserialize_pte_binary from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer @@ -32,6 +33,16 
@@ ) +INFERENCE_REGISTRY = {} + + +def register_inference(use_kv_cache: bool): + def decorator(func): + INFERENCE_REGISTRY[use_kv_cache] = func + + return decorator + + class GraphModuleCalibrationWrapper(EagerEvalWrapper): """ A wrapper class for calibration @@ -65,28 +76,20 @@ def __init__( def _model_call(self, inps): all_logits = None + kwargs = {} if self._use_kv_cache: - all_logits = kv_inference( - self.get_example_inputs, - inps, - self._model, - self._tokenizer, - self.ar_len, - self.max_seq_length, - kv_updater=self.kv_updater, - use_i64_token=self.use_i64_token, - collect_logits=True, - ) - else: - all_logits = prefill_inference( - self.get_example_inputs, - inps, - self._model, - self._tokenizer, - self.max_seq_length, - use_i64_token=self.use_i64_token, - collect_logits=True, - ) + kwargs["ar_len"] = self.ar_len + kwargs["kv_updater"] = self.kv_updater + all_logits = INFERENCE_REGISTRY[self._use_kv_cache]( + self.get_example_inputs, + inps, + self._model, + self._tokenizer, + max_seq_len=self.max_seq_length, + use_i64_token=self.use_i64_token, + collect_logits=True, + **kwargs, + ) return all_logits @@ -231,7 +234,13 @@ def post_process(): def smart_mask_updater( - _, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + n_updates: int, + atten_mask: AttentionMask, + pos, + k_caches, + v_caches, + new_k_caches, + new_v_caches, ): # ar_len is unused in smart mask max_cache_len = k_caches[0].size(-1) @@ -241,14 +250,20 @@ def smart_mask_updater( for i, v_cache in enumerate(v_caches): v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :] - atten_mask[:, :, pos : pos + n_updates] = 0 + atten_mask.smart_mask_update(pos, n_updates) pos += n_updates - return (atten_mask, pos, k_caches, v_caches) + return pos, k_caches, v_caches def shift_pointer_updater( - ar_len, n_updates, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + n_updates: int, + atten_mask: AttentionMask, + pos, + k_caches, + v_caches, + new_k_caches, + new_v_caches, ): max_cache_len = k_caches[0].size(-1) if pos + n_updates <= max_cache_len: @@ -264,12 +279,13 @@ def shift_pointer_updater( ) for i, v_cache in enumerate(v_caches) ] - atten_mask[:, :, -pos - n_updates - ar_len : -pos - ar_len] = 0 + atten_mask.shift_pointer_update(pos, n_updates) pos += n_updates - return (atten_mask, pos, k_caches, v_caches) + return pos, k_caches, v_caches +@register_inference(use_kv_cache=True) def kv_inference( get_example_inputs, prompt: Union[str, list], @@ -331,7 +347,7 @@ def kv_inference( # Run inference. logits, new_k_caches, new_v_caches = module( tmp_token_list, - atten_mask, + *atten_mask, tmp_pos, *k_caches, *v_caches, @@ -340,8 +356,7 @@ def kv_inference( result_logits.append(logits[:, :num_tokens_in_chunk]) # Update the pos, KV cache and attention mask. 
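The calibration path above now dispatches through a small registry keyed by `use_kv_cache` rather than branching at each call site. A minimal sketch of the pattern; returning `func` from the inner decorator is the conventional touch that also keeps the decorated name bound at module level (the code in `decoder_utils.py` itself only resolves functions through `INFERENCE_REGISTRY`):

```python
# Minimal sketch of the use_kv_cache -> inference-function registry.
from typing import Callable, Dict

INFERENCE_REGISTRY: Dict[bool, Callable] = {}


def register_inference(use_kv_cache: bool):
    def decorator(func):
        INFERENCE_REGISTRY[use_kv_cache] = func
        return func  # keep the module-level name bound as well

    return decorator


@register_inference(use_kv_cache=True)
def kv_inference(*args, **kwargs):
    """Stand-in for the real kv_inference."""


@register_inference(use_kv_cache=False)
def prefill_inference(*args, **kwargs):
    """Stand-in for the real prefill_inference."""


assert INFERENCE_REGISTRY[True] is kv_inference
assert INFERENCE_REGISTRY[False] is prefill_inference
```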
- atten_mask, pos, k_caches, v_caches = kv_updater( - ar_len, + pos, k_caches, v_caches = kv_updater( num_tokens_in_chunk, atten_mask, pos, @@ -378,7 +393,7 @@ def kv_inference( logits, new_k_caches, new_v_caches = module( tmp_token_list, - atten_mask, + *atten_mask, tmp_pos, *k_caches, *v_caches, @@ -386,8 +401,7 @@ def kv_inference( if collect_logits: result_logits.append(logits[:, :num_tokens_in_chunk]) - atten_mask, pos, k_caches, v_caches = kv_updater( - ar_len, + pos, k_caches, v_caches = kv_updater( 1, atten_mask, pos, @@ -406,6 +420,7 @@ def kv_inference( return result_logits +@register_inference(use_kv_cache=False) def prefill_inference( get_example_inputs, prompt: Union[str, list], @@ -449,12 +464,9 @@ def prefill_inference( ], dim=1, ) - results = module( - tmp_token_list, - atten_mask, - ) + results = module(tmp_token_list, *atten_mask) if len(results) == 3: - logits, new_k_caches, new_v_caches = results + logits, _, _ = results elif len(results) == 1: logits = results token = torch.argmax(logits[:, pos - 1], dim=-1).item() @@ -492,28 +504,20 @@ def graph_module_inference( prompt is None ), "Please provide either tasks or prompt - not both or neither" if tasks is None: + kwargs = {} if use_kv_cache: - kv_inference( - get_example_inputs, - prompt, - module, - tokenizer, - ar_len, - max_seq_len, - kv_updater=kv_updater, - use_i64_token=use_i64_token, - collect_logits=False, - ) - else: - prefill_inference( - get_example_inputs, - prompt, - module, - tokenizer, - max_seq_len, - use_i64_token, - collect_logits=False, - ) + kwargs["ar_len"] = ar_len + kwargs["kv_updater"] = kv_updater + INFERENCE_REGISTRY[use_kv_cache]( + get_example_inputs, + prompt, + module, + tokenizer, + max_seq_len=max_seq_len, + use_i64_token=use_i64_token, + collect_logits=False, + **kwargs, + ) else: calibration_wrapper = GraphModuleCalibrationWrapper( model=module, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index f9196d38750..73a804219d5 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -62,6 +62,7 @@ get_quant_embedding_transform, ) from executorch.examples.qualcomm.oss_scripts.llama import ( + LLM_VARIANT_ARCHS, LLMModelConfig, SUPPORTED_LLM_MODELS, ) @@ -137,14 +138,14 @@ def __init__( self.llama_meta = self.decoder_model.get_metadata() self.has_quant_io = False self.pte_filename = pte_filename - if self.llama_meta["get_use_kv_cache"]: - tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( - use_kv_cache=True - ) - self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) - else: - tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) - self.inputs = (tokens, atten_mask) + inputs = self.get_example_inputs(self.llama_meta["get_use_kv_cache"]) + self.inputs = ( + inputs[0], # tokens + *inputs[1], # attn_mask + *(inputs[2] if self.llama_meta["get_use_kv_cache"] else []), # pos_ids + *(inputs[3] if self.llama_meta["get_use_kv_cache"] else []), # k_caches + *(inputs[4] if self.llama_meta["get_use_kv_cache"] else []), # v_caches + ) self.llama_graph_module = decoder_model self.io_shape = { # logit output @@ -227,7 +228,6 @@ def quantize( quantizer = make_custom_quantizer( quant_dtype, args.range_setting, custom_annotations ) - self.has_quant_io = True fx_graph_module = None with torch.no_grad(): @@ -451,39 +451,60 @@ def compile( llama_instance_list = [] use_i64_token = args.embedding_quantize is not None + extra_kwargs = {} + if 
args.decoder_model == "gemma3-1b": + from transformers import Gemma3Config + + hf_config = Gemma3Config.from_pretrained(decoder_model_config.repo_id) + extra_kwargs["layer_types"] = hf_config.text_config.layer_types + extra_kwargs["rope_local_base_freq"] = ( + hf_config.text_config.rope_local_base_freq + ) + extra_kwargs["sliding_window"] = hf_config.sliding_window + with torch.device("meta"): if args.model_mode == "kv": llama_instance_list.append( - LlamaModel( + LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( kv_config, ar_len=1, output_new_cache_only=True, output_cache=True, use_i64_token=use_i64_token, + **extra_kwargs, ) ) elif args.model_mode == "hybrid": llama_instance_list.append( - LlamaModel( + LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( kv_config, ar_len=1, output_new_cache_only=True, output_cache=True, use_i64_token=use_i64_token, + **extra_kwargs, ) ) llama_instance_list.append( - LlamaModel( + LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( prefill_config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=True, use_i64_token=use_i64_token, + **extra_kwargs, ) ) elif args.model_mode == "lookahead": + # TODO: Lookahead decoding is not yet supported for gemma3-1b. + # This will be implemented once the model architecture and KV update logic are adapted. + if args.decoder_model == "gemma3-1b": + raise NotImplementedError( + "gemma3-1b does not currently support lookahead decoding." + ) + llama_instance_list.append( - LlamaModel( + LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( kv_config, # To get better performance, we round up to the nearest power of 2. ar_len=next_power_of_two( @@ -492,15 +513,17 @@ def compile( output_new_cache_only=True, output_cache=True, use_i64_token=use_i64_token, + **extra_kwargs, ) ) llama_instance_list.append( - LlamaModel( + LLM_VARIANT_ARCHS.get(args.decoder_model, LlamaModel)( prefill_config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=True, use_i64_token=use_i64_token, + **extra_kwargs, ) ) else: @@ -514,6 +537,14 @@ def compile( state_dict = torch.load( checkpoint, weights_only=True, map_location="cpu", mmap=True ) + if args.decoder_model == "gemma3-1b": + for k, v in state_dict.items(): + if "norm" not in k: + continue + # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32) + else: state_dict = torch.load( args.checkpoint, weights_only=True, map_location="cpu", mmap=True @@ -591,9 +622,13 @@ def permute(w, heads): model.to(torch.float) ar_len, model.ar_len = model.ar_len, model.max_seq_len tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) - atten_mask.to(torch.float) + atten_mask.mask.to(torch.float) wrapped_model = WrappedLlamaModel( - model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device + model, + atten_mask.mask, + model.use_kv_cache, + args.max_seq_len, + args.device, ) act_bits, weight_bits = { QuantDtype.use_8a8w: (8, 8), @@ -659,7 +694,6 @@ def permute(w, heads): if decoder_model_config.ptq: start_quantize_ts = time.time() custom_annotations = decoder_model_config.custom_annotation - logging.info(f"Custom annotations applied: {custom_annotations}") kv_quant_attrs = {} for i, llama_instance in enumerate(llama_instance_list): llama_instance.quantize( @@ -1155,6 +1189,7 @@ def export_llama(args) -> None: args.decoder_model in SUPPORTED_LLM_MODELS ), f"Unknown decoder_model: 
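The gemma3-1b norm-weight adjustment above (`v.float() + torch.ones(...)`) folds Gemma's `(1 + weight)` RMSNorm convention into a plain Llama-style weight, so the unmodified `torch.nn.RMSNorm` modules in `static_llama.py` reproduce the HF behaviour (up to the float16 cast-order difference the inline comment points to). A numeric sketch of that equivalence, assuming the standard Gemma `(1 + w)` formulation:

```python
# Hedged sketch: HF Gemma3 RMSNorm computes rmsnorm(x) * (1 + w); a Llama-style
# torch.nn.RMSNorm computes rmsnorm(x) * w. Shifting the checkpoint weight by
# +1 (as done above for gemma3-1b) makes the two match.
import torch


def rmsnorm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


x = torch.randn(2, 8)
w_hf = torch.randn(8) * 0.1              # norm weight as stored in the checkpoint

llama_style = torch.nn.RMSNorm(8, eps=1e-6)
with torch.no_grad():
    llama_style.weight.copy_(w_hf + 1.0)  # the adjustment applied at load time

gemma_style = rmsnorm(x) * (1.0 + w_hf)   # assumed HF Gemma3 formulation
assert torch.allclose(llama_style(x), gemma_style, atol=1e-6)
```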
{args.decoder_model}." decoder_model_config = SUPPORTED_LLM_MODELS[args.decoder_model] + logging.info(f"*** {args.decoder_model} ***\n%s", str(decoder_model_config)) if args.model_mode == "kv": pte_filename = "kv_llama_qnn" diff --git a/examples/qualcomm/oss_scripts/llama/masking_utils.py b/examples/qualcomm/oss_scripts/llama/masking_utils.py new file mode 100644 index 00000000000..bed81c894f0 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/masking_utils.py @@ -0,0 +1,323 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import List, Union + +import torch + + +def create_causal_attn_mask(max_batch_size: int, ar_len: int, max_seq_len: int): + """ + Creating a causal attention mask (ar_len: 5, max_seq_len: 15) + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● + + ● = activate (can attend), ○ = inactivate (masked) + """ + mask = torch.full((ar_len, ar_len), -255.0) + mask_cond = torch.arange(ar_len) + mask.masked_fill_(mask_cond.view(1, ar_len) <= mask_cond.view(ar_len, 1), 0) + + if max_seq_len != ar_len: + mask = torch.cat( + [ + torch.ones(ar_len, max_seq_len - ar_len) * -255.0, + mask, + ], + dim=-1, + ) + mask = mask[None, :, :].expand(max_batch_size, ar_len, max_seq_len) + return mask + + +def create_sliding_window_attn_mask( + max_batch_size: int, ar_len: int, max_seq_len: int, sliding_window: int +): + """ + Creating a sliding_window attention mask (ar_len: 5, max_seq_len: 15, sliding_window: 3) + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + + ● = activate (can attend), ○ = inactivate (masked) + """ + mask = torch.full((ar_len, ar_len), -255.0) + mask_cond = torch.arange(ar_len) + mask.masked_fill_( + (mask_cond.view(1, ar_len) <= mask_cond.view(ar_len, 1)) + & (mask_cond.view(ar_len, 1) - mask_cond.view(1, ar_len) < sliding_window), + 0, + ) + + if max_seq_len != ar_len: + mask = torch.cat( + [ + torch.ones(ar_len, max_seq_len - ar_len) * -255.0, + mask, + ], + dim=-1, + ) + mask = mask[None, :, :].expand(max_batch_size, ar_len, max_seq_len) + return mask + + +class BaseAttentionMask(ABC): + def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int): + """ + Base class for attention masks used in autoregressive or hybrid attention mechanisms. + + Args: + max_batch_size (int): Maximum batch size supported. + ar_len (int): Length of the autoregressive sequence. + max_seq_len (int): Maximum sequence length. + """ + self.max_batch_size = max_batch_size + self.ar_len = ar_len + self.max_seq_len = max_seq_len + + @property + @abstractmethod + def mask(self) -> torch.Tensor: + """ + Attention mask tensor that must be initialized by child classes. + """ + pass + + @abstractmethod + def smart_mask_update(self, pos, n_updates): + """ + Update the attention mask by smart mask update method after model forward. + + Args: + pos (int): Current position in the sequence. + n_updates (int): Number of new tokens to update. + """ + pass + + @abstractmethod + def shift_pointer_update(self, pos, n_updates): + """ + Update the attention mask by shift pointer update method after model forward. 
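A quick executable check of `create_causal_attn_mask` above: with `ar_len == max_seq_len` there is no left cache padding and the result is the usual lower-triangular pattern, with `-255.0` on masked positions (import path matches the one added to `static_llama.py`):

```python
# Executable version of the causal-mask diagram for a small case.
import torch

from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import (
    create_causal_attn_mask,
)

mask = create_causal_attn_mask(max_batch_size=1, ar_len=4, max_seq_len=4)
expected = torch.full((4, 4), -255.0).masked_fill_(
    torch.tril(torch.ones(4, 4, dtype=torch.bool)), 0.0
)
assert torch.equal(mask[0], expected)
```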
+ + Args: + pos (int): Current position in the sequence. + n_updates (int): Number of tokens to shift. + """ + pass + + +class CausalAttentionMask(BaseAttentionMask): + def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int): + super().__init__(max_batch_size, ar_len, max_seq_len) + self._mask = create_causal_attn_mask(max_batch_size, ar_len, max_seq_len) + + @property + def mask(self): + return self._mask + + def smart_mask_update(self, pos, n_updates): + """ + Smart Mask mechanism for attention mask updating + Initial mask(5x15) layout (before any updates): + Each row represents a query token in the autoregressive context. + ● = activate (can attend), ○ = inactivate (masked) + + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● + + After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + Newly added tokens are unmasked (set to 0). + + 0 ● ● ● ● ● ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ● ● ● ● ● ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ● ● ● ● ● ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ● ● ● ● ● ○ ○ ○ ○ ○ ● ● ● ● ○ + 4 ● ● ● ● ● ○ ○ ○ ○ ○ ● ● ● ● ● + + After 2nd update (e.g., pos=5, n_updates=5): + + 0 ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ ○ + 1 ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ + 2 ● ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ + 3 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ○ + 4 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ● + """ + start_pos = pos + end_pos = pos + n_updates + self.mask[:, :, start_pos:end_pos] = 0 + + def shift_pointer_update(self, pos, n_updates): + """ + Shift Pointer mechanism for attention mask updating + + Initial mask(5x15) layout (before any updates): + Each row represents a query token in the autoregressive context. + ● = activate (can attend), ○ = inactivate (masked) + + Init mask: + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● + + After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + Newly added tokens are unmasked (set to 0). + 0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ● ○ + 4 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ● ● + + After 2nd update (e.g., pos=5, n_updates=5): + 0 ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ ○ + 1 ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ + 2 ● ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ + 3 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ○ + 4 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ● + """ + start_pos = -pos - n_updates - self.ar_len + end_pos = -pos - self.ar_len + self.mask[:, :, start_pos:end_pos] = 0 + + +class SlidingWindowAttentionMask(BaseAttentionMask): + def __init__( + self, + max_batch_size: int, + ar_len: int, + max_seq_len: int, + sliding_window: int, + ): + super().__init__(max_batch_size, ar_len, max_seq_len) + self._mask = create_sliding_window_attn_mask( + max_batch_size, ar_len, max_seq_len, sliding_window + ) + self.sliding_window = sliding_window + + @property + def mask(self): + return self._mask + + def smart_mask_update(self, pos, n_updates): + """ + Smart Mask mechanism for attention mask updating + + Initial mask(5x15) layout (before any updates): + Each row represents a query token in the autoregressive context. + ● = activate (can attend), ○ = inactivate (masked) + + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + + + After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + Newly added tokens are unmasked (set to 0). 
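The update diagrams translate directly into tensor writes; a small check of `CausalAttentionMask.smart_mask_update` using the same shapes as the diagrams (batch 1, `ar_len` 5, `max_seq_len` 15):

```python
# Exercise CausalAttentionMask.smart_mask_update with the diagram's shapes.
import torch

from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import (
    CausalAttentionMask,
)

m = CausalAttentionMask(max_batch_size=1, ar_len=5, max_seq_len=15)
assert torch.all(m.mask[:, :, :10] == -255.0)   # cache region starts fully masked

m.smart_mask_update(pos=0, n_updates=5)
assert torch.all(m.mask[:, :, 0:5] == 0)        # newly cached tokens attendable
assert torch.all(m.mask[:, :, 5:10] == -255.0)  # rest of the cache still masked
```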
+ Earlier tokens lose access to older cache due to sliding window limits. + + 0 ○ ○ ○ ● ● ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ● ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + + + After 2nd update (e.g., pos=5, n_updates=5): + Sliding window shifts again, masking older positions and activate new postion. + + 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + """ + start_pos = pos + end_pos = pos + n_updates + # Unmask the same range in the sliding window mask + self.mask[:, :, start_pos:end_pos] = 0 + + for i in range(self.ar_len): + # Calculate how many cached tokens are still avalible for this row + avalible_cache_len = self.sliding_window - (i + 1) + + # If the current position exceeds available cache, mask the overflow + if end_pos > avalible_cache_len: + # Mask tokens that are no longer within the sliding window + # TODO: [Optional]: it can be optimized by computing the exact start index + self.mask[:, i, : end_pos - avalible_cache_len] = -255.0 + + def shift_pointer_update(self, pos, n_updates): + """ + Shift Pointer mechanism for attention mask updating + + Initial mask(5x15) layout (before any updates): + Each row represents a query token in the autoregressive context. + ● = activate (can attend), ○ = inactivate (masked) + + Init mask: + 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + + After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): + 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + + After 2nd update (e.g., pos=5, n_updates=5): + 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ + 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ + 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ + 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ + 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● + """ + + start_pos = -pos - n_updates - self.ar_len + end_pos = -pos - self.ar_len + self.mask[:, :, start_pos:end_pos] = 0 + + for i in range(self.ar_len): + avalible_cache_len = self.sliding_window - (i + 1) + if abs(start_pos + self.ar_len) > avalible_cache_len: + self.mask[ + :, + i, + start_pos : start_pos + + abs(start_pos + self.ar_len) + - avalible_cache_len, + ] = -255.0 + + +class AttentionMask: + def __init__(self, masks: Union[BaseAttentionMask, List[BaseAttentionMask]]): + self.masks = masks if isinstance(masks, list) else [masks] + + def smart_mask_update(self, pos, n_updates): + for mask in self.masks: + mask.smart_mask_update(pos, n_updates) + + def shift_pointer_update(self, pos, n_updates): + for mask in self.masks: + mask.shift_pointer_update(pos, n_updates) + + def __iter__(self): + return iter([mask.mask for mask in self.masks]) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 08c67e9d1d6..32764eba985 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -13,12 +13,16 @@ import scipy import torch import torch.nn as nn -import torch.nn.functional as F from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import ( hf_precompute_freqs_cis, precompute_freqs_cis, ) +from 
executorch.examples.qualcomm.oss_scripts.llama.masking_utils import ( + AttentionMask, + CausalAttentionMask, + SlidingWindowAttentionMask, +) def apply_rotary_emb_single( @@ -358,6 +362,7 @@ def __init__(self, args: ModelArgs): self.w1 = nn.Linear(self.dim, self.hidden_dim, bias=False) self.w2 = nn.Linear(self.hidden_dim, self.dim, bias=False) self.w3 = nn.Linear(self.dim, self.hidden_dim, bias=False) + self.act_fn = args.act_fn.get_function() def prepare_feedfoward_conv(self): self.w1_conv = nn.Conv2d(self.dim, self.hidden_dim, 1, bias=False) @@ -379,13 +384,13 @@ def forward_feedfoward_conv(self, x): bsz, _, _ = x.size() x = torch.reshape(x, (bsz, -1, 1, self.dim)) x = x.transpose(1, 3) # Transpose right before and after Conv - x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x)) + x = self.w2_conv(self.act_fn(self.w1_conv(x)) * self.w3_conv(x)) x = x.transpose(1, 3) x = torch.reshape(x, (bsz, -1, self.dim)) return x def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) class LlamaDecoderLayer(nn.Module): @@ -399,6 +404,16 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False): self.feed_forward = FeedForward(config) self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps) + self.post_attention_norm = ( + torch.nn.RMSNorm(config.dim, eps=config.norm_eps) + if config.post_attention_norm + else None + ) + self.post_ffn_norm = ( + torch.nn.RMSNorm(config.dim, eps=config.norm_eps) + if config.post_ffn_norm + else None + ) def forward( self, @@ -417,8 +432,13 @@ def forward( k_caches=k_caches, v_caches=v_caches, ) + if self.post_attention_norm: + h = self.post_attention_norm(h) h = x + h - output = h + self.feed_forward(self.ffn_norm(h)) + out = self.feed_forward(self.ffn_norm(h)) + if self.post_ffn_norm: + out = self.post_ffn_norm(out) + output = h + out return output, k_cache, v_cache @@ -430,6 +450,7 @@ def __init__( output_new_cache_only=True, output_cache=True, use_i64_token=False, + **kwargs, ): super().__init__() self.dim = config.dim @@ -442,6 +463,7 @@ def __init__( self.vocab_size = config.vocab_size self.rope_freq_base = config.rope_freq_base self.use_kv_cache = config.use_kv_cache + self.embedding_scale_factor = config.embedding_scale_factor self.ar_len = ar_len self.output_new_cache_only = output_new_cache_only self.use_i64_token = use_i64_token @@ -511,7 +533,7 @@ def forward( self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin ) - hidden_states = self.tok_embeddings(tokens) + hidden_states = self.embedding_scale_factor * self.tok_embeddings(tokens) for ind, decoder_layer in enumerate(self.layers): k_caches = None v_caches = None @@ -543,22 +565,8 @@ def get_example_inputs(self, use_kv_cache=True): tokens = torch.randint( self.vocab_size, (self.max_batch_size, self.ar_len), dtype=dtype ) - - atten_mask = torch.full((self.ar_len, self.ar_len), torch.tensor(-255.0)) - mask_cond = torch.arange(atten_mask.size(-1)) - atten_mask.masked_fill_( - mask_cond < (mask_cond + 1).view(atten_mask.size(-1), 1), 0 - ) - if self.max_seq_len != self.ar_len: - atten_mask = torch.cat( - [ - torch.ones(self.ar_len, self.max_seq_len - self.ar_len) * -255.0, - atten_mask, - ], - dim=-1, - ) - atten_mask = atten_mask[None, :, :].expand( - self.max_batch_size, self.ar_len, self.max_seq_len + atten_mask = AttentionMask( + CausalAttentionMask(self.max_batch_size, self.ar_len, self.max_seq_len) ) if use_kv_cache: 
pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32) @@ -612,3 +620,130 @@ def get_metadata(self): "get_use_kv_cache": self.use_kv_cache, "get_kv_io_bit_width": self.kv_io_bit_width, } + + +class MultiScopeAwareLlamaModel(LlamaModel): + def __init__( + self, + config: ModelArgs, + ar_len=1, + output_new_cache_only=True, + output_cache=True, + use_i64_token=False, + **kwargs, + ): + super().__init__( + config=config, + ar_len=ar_len, + output_new_cache_only=output_new_cache_only, + output_cache=output_cache, + use_i64_token=use_i64_token, + ) + + for key in ["layer_types", "sliding_window", "rope_local_base_freq"]: + assert key in kwargs, f"Missing required argument: '{key}' in kwargs" + + # Get attention type for each layer + self.layer_types = kwargs["layer_types"] + # Get sliding window size (used in local/global attention) + self.sliding_window = kwargs["sliding_window"] + # Get local freq base for sliding attention + rope_freq_base = kwargs["rope_local_base_freq"] + + local_freqs_cos, local_freqs_sin = hf_precompute_freqs_cis( + config.head_dim, + config.max_seq_len, + rope_freq_base, + config.partial_rotary_factor, + ) + local_freqs_cos = local_freqs_cos[:, : local_freqs_cos.shape[-1] // 2] + local_freqs_sin = local_freqs_sin[:, : local_freqs_sin.shape[-1] // 2] + self.register_buffer("local_freqs_cos", local_freqs_cos, persistent=False) + self.register_buffer("local_freqs_sin", local_freqs_sin, persistent=False) + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + window_atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = ( + self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos + ) + freqs_sin = ( + self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin + ) + local_freqs_cos = ( + self.local_freqs_cos[input_pos][0] + if self.use_kv_cache + else self.local_freqs_cos + ) + local_freqs_sin = ( + self.local_freqs_sin[input_pos][0] + if self.use_kv_cache + else self.local_freqs_sin + ) + + hidden_states = self.embedding_scale_factor * self.tok_embeddings(tokens) + for ind, decoder_layer in enumerate(self.layers): + k_caches = None + v_caches = None + if self.use_kv_cache: + offset_k = ind * self.n_kv_heads + offset_v = self.n_layers * self.n_kv_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_kv_heads] + v_caches = args[offset_v : offset_v + self.n_kv_heads] + + if self.layer_types[ind] == "sliding_attention": + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=local_freqs_cos, + freqs_sin=local_freqs_sin, + atten_mask=window_atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + else: + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + + output_k_cache.extend(k) + output_v_cache.extend(v) + + hidden_states = self.norm(hidden_states) + logits = self.output(hidden_states) + if self.output_cache: + return logits, output_k_cache, output_v_cache + return logits + + def get_example_inputs(self, use_kv_cache=True): + inputs = list(super().get_example_inputs(use_kv_cache=use_kv_cache)) + causal_mask = CausalAttentionMask( + self.max_batch_size, self.ar_len, self.max_seq_len + ) + sliding_window_mask = SlidingWindowAttentionMask( + self.max_batch_size, 
+ self.ar_len, + self.max_seq_len, + sliding_window=self.sliding_window, + ) + # Don't reverse the order of attention mask + inputs[1] = AttentionMask([causal_mask, sliding_window_mask]) + return tuple(inputs) + + def get_metadata(self): + meta_data = super().get_metadata() + meta_data["get_sliding_window"] = self.sliding_window + return meta_data diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index dab74dc966b..80d92beb099 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -9,8 +9,9 @@ /** * @file * - * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B / 1.5B, Qwen3 - * 0.6B / 1.7B, phi4-mini-instruct, Smollm2 135M with Qualcomm AI Engine Direct. + * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B, + * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, Smollm2 135M with + * Qualcomm AI Engine Direct. * */ @@ -102,19 +103,29 @@ std::string get_formatted_prompt( std::string formatted_prompt; switch (decoder_model_version) { case example::DecoderModelVersion::kLlama2: - case example::DecoderModelVersion::kQwen2_5: formatted_prompt.append(prompt); break; - case example::DecoderModelVersion::kQwen3: - formatted_prompt.append("<|im_start|>user\n"); + case example::DecoderModelVersion::kLlama3: + if (!system_prompt.empty()) { + formatted_prompt.append( + "<|start_header_id|>system<|end_header_id|>\n\n"); + formatted_prompt.append(system_prompt); + formatted_prompt.append("<|eot_id|>"); + } + formatted_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n"); formatted_prompt.append(prompt); - formatted_prompt.append("<|im_end|>\n"); + formatted_prompt.append( + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + break; + case example::DecoderModelVersion::kGemma3: + formatted_prompt.append("user\n"); + formatted_prompt.append(prompt); + formatted_prompt.append("\n"); + formatted_prompt.append("model\n"); if (!system_prompt.empty()) { - formatted_prompt.append("<|im_start|>system\n"); formatted_prompt.append(system_prompt); - formatted_prompt.append("<|im_end|>\n"); + formatted_prompt.append("\n"); } - formatted_prompt.append("<|im_start|>assistant"); break; case example::DecoderModelVersion::kPhi4: if (!system_prompt.empty()) { @@ -125,27 +136,30 @@ std::string get_formatted_prompt( formatted_prompt.append("<|user|>"); formatted_prompt.append(prompt); formatted_prompt.append("<|end|><|assistant|>"); - case example::DecoderModelVersion::kSmollm2_135m: + break; + case example::DecoderModelVersion::kQwen2_5: + formatted_prompt.append(prompt); + break; + case example::DecoderModelVersion::kQwen3: + formatted_prompt.append("<|im_start|>user\n"); + formatted_prompt.append(prompt); + formatted_prompt.append("<|im_end|>\n"); if (!system_prompt.empty()) { formatted_prompt.append("<|im_start|>system\n"); formatted_prompt.append(system_prompt); - formatted_prompt.append("<|im_end|>\n\n"); + formatted_prompt.append("<|im_end|>\n"); } - formatted_prompt.append("<|im_start|>user\n"); - formatted_prompt.append(prompt); - formatted_prompt.append("<|im_end|>\n\n"); + formatted_prompt.append("<|im_start|>assistant"); break; - case example::DecoderModelVersion::kLlama3: + case example::DecoderModelVersion::kSmollm2_135m: if (!system_prompt.empty()) { - formatted_prompt.append( - "<|start_header_id|>system<|end_header_id|>\n\n"); + formatted_prompt.append("<|im_start|>system\n"); 
formatted_prompt.append(system_prompt); - formatted_prompt.append("<|eot_id|>"); + formatted_prompt.append("<|im_end|>\n\n"); } - formatted_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n"); + formatted_prompt.append("<|im_start|>user\n"); formatted_prompt.append(prompt); - formatted_prompt.append( - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + formatted_prompt.append("<|im_end|>\n\n"); break; default: ET_CHECK_MSG(false, "unsupported llama version"); diff --git a/examples/qualcomm/oss_scripts/llama/runner/cache_utils.h b/examples/qualcomm/oss_scripts/llama/runner/cache_utils.h new file mode 100644 index 00000000000..11eeb52d9b4 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/runner/cache_utils.h @@ -0,0 +1,15 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +enum CacheMode { + StaticCahce = 0, + // For models with global/local attention architecture (e.g., Gemma3), + HybridCache, +}; diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index 9ce1abafa04..a049b54abb6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -116,6 +116,84 @@ void KVManager::init_attention_mask( } } +template +void KVManager::init_attention_mask( + uint16_t* attention_mask, + const std::vector& attention_map, + int32_t ar_len, + int32_t n_past, + int32_t sliding_window) { + ET_CHECK_MSG( + attention_map.size() <= ar_len, + "The size of attention_map (%zu) doesn't match with ar_len (%d)", + attention_map.size(), + ar_len); + uint16_t neg_val = 0; + uint16_t pos_val = 65535; + // Clear the attention mask + std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + + // SMART_MASK requires special handling of attention mask + switch (kv_updater_) { + case KVManagerMode::SMART_MASK: { + uint16_t* past_ptr = attention_mask; + uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + // All inputs will necessarily attend to n_past and itself + for (int i = 0; i < ar_len; i++) { + // Iterate across ar_len + if (attention_map[i] < 0) { + // If negative, attend to only past tokens + std::fill_n(past_ptr, n_past, pos_val); + } else { + // If positive, copy attention map from (relative to 0th input) parent + // Parent token index + const int32_t pidx = attention_map[i]; + uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::memcpy( + past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + } + // Attend to itself + new_ptr[i] = pos_val; + + // mask by limitation of sliding_window + int32_t avalible_context_len = sliding_window - (i + 1) - n_past; + if (n_past > avalible_context_len) { + std::fill_n(past_ptr, n_past - avalible_context_len, neg_val); + } + + past_ptr += metadata_.context_len; + new_ptr += metadata_.context_len; + } + break; + } + case KVManagerMode::SHIFT_POINTER: { + // Only fill in ar_len. 
Rest will be padding + const size_t attn_row_start = metadata_.context_len - n_past - ar_len; + for (int i = 0; i < ar_len; i++) { + uint16_t* cur_ptr = + attention_mask + i * metadata_.context_len + attn_row_start; + // Attend to itself + cur_ptr[n_past + i] = pos_val; + if (attention_map[i] < 0) { + // If negative, attend to only past tokens + std::fill_n(cur_ptr, n_past, pos_val); + } else { + // If positive, copy attention map from (relative to 0th input) parent + // Parent token index + const int32_t pidx = attention_map[i]; + uint16_t* parent_ptr = + attention_mask + pidx * metadata_.context_len + attn_row_start; + std::memcpy( + cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(uint16_t)); + } + } + break; + } + default: + break; + } +} + template void KVManager::update_attention_mask( uint16_t* attention_mask, @@ -135,6 +213,41 @@ void KVManager::update_attention_mask( } } +template +void KVManager::update_attention_mask( + uint16_t* attention_mask, + int32_t ar_len, + int32_t n_past, + int32_t n_update, + int32_t sliding_window) { + uint16_t pos_val = 65535; + uint16_t neg_val = 0; + uint16_t* cur_ptr = attention_mask; + if (kv_updater_ == KVManagerMode::SMART_MASK) + cur_ptr += n_past; + if (kv_updater_ == KVManagerMode::SHIFT_POINTER) + cur_ptr += metadata_.context_len - n_past - ar_len - n_update; + + for (int i = 0; i < ar_len; i++) { + std::fill_n(cur_ptr, n_update, pos_val); + int32_t avalible_cache_len = sliding_window - (i + 1); + if (kv_updater_ == KVManagerMode::SMART_MASK) { + if (n_past + n_update > avalible_cache_len) { + std::fill_n( + cur_ptr - n_past, n_past + n_update - avalible_cache_len, neg_val); + } + } else if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { + if (std::abs(n_past + ar_len) > avalible_cache_len) { + int32_t n_invalid = n_past - avalible_cache_len; + std::fill_n( + cur_ptr, std::abs(n_past + ar_len) - avalible_cache_len, neg_val); + } + + cur_ptr += metadata_.context_len; + } + } +} + template void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { cur_ar_len_ = ar_len; diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index c20a5a1ab60..af9cf49a34f 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -78,6 +78,31 @@ class KVManager { int32_t ar_len, int32_t n_past); + /** + * @brief Initialize attention mask based on kv manager mode, and attention + * map. + * For example, + * ar_len = 4, CL = 6, n_past = 0, + * attention map: {-1, 0, 1, 2} and SMART_MASK. + * Attention_mask will be: + * [ 0 0 65535 0 0 0 ] + * [ 0 0 65535 65535 0 0 ] + * [ 0 0 65535 65535 65535 0 ] + * [ 0 0 65535 65535 65535 65535 ] + * @param attention_mask Pointer to the attention mask array to be + * initialized. + * @param attention_map Vector containing the attention map values. The shape + * of attention map should be [ar_len]. + * @param ar_len Length of input tokens. + * @param n_past Number of past elements in the cache. + */ + void init_attention_mask( + uint16_t* attention_mask, + const std::vector& attention_map, + int32_t ar_len, + int32_t n_past, + int32_t sliding_window); + /** * @brief Update attention mask based on kv manager mode, and n_update. * @param attention_mask Pointer to the attention mask array to be @@ -92,6 +117,23 @@ class KVManager { int32_t n_past, int32_t n_update); + /** + * @brief Update attention mask based on kv manager mode, and n_update. 
+ * @param attention_mask Pointer to the attention mask array to be + * initialized. + * @param ar_len Length of input tokens. + * @param n_past Number of past elements in the cache. + * @param n_update Number of elements to be updated. + * @param sliding_window Length of sliding window for sliding window attention + * mask + */ + void update_attention_mask( + uint16_t* attention_mask, + int32_t ar_len, + int32_t n_past, + int32_t n_update, + int32_t sliding_window); + /** * @brief Reset the data pointer of the I/O cache tensor based on number of * past cache, kv manager mode, current ar length and KV cache data pointer diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h index 174c7f7504f..fe5e4b49230 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -28,6 +28,7 @@ class LhdTokenGenerator : public TokenGenerator { int32_t ngram; int32_t window; int32_t gcap; + int sliding_window; }; LhdTokenGenerator( tokenizers::Tokenizer* tokenizer, @@ -49,7 +50,8 @@ class LhdTokenGenerator : public TokenGenerator { metadata.num_layers, metadata.ar_len, metadata.vocab_size, - metadata.use_int64_token}, + metadata.use_int64_token, + metadata.sliding_window}, stats), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 787185c2249..73da764b584 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -35,19 +35,35 @@ PromptProcessor::PromptProcessor( input_pos_.size = 0; else input_pos_.size = metadata_.ar_len * sizeof(int32_t); - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + + switch (metadata_.cache_mode) { + case CacheMode::StaticCahce: + attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + window_attention_mask_.size = 0; + break; + case CacheMode::HybridCache: + attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + window_attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + break; + default: + ET_CHECK_MSG(false, "Unsupported llama cache mode"); + break; + } + logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); }; - template void PromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { + size_t idx = 0; input_tensors_.reserve(method_meta->num_inputs()); output_tensors_.reserve(method_meta->num_outputs()); // [I]: input_tokens - Result input_toks = method_meta->input_tensor_meta(0); + Result input_toks = method_meta->input_tensor_meta(idx++); input_toks_.data = reinterpret_cast(buffer_manager->allocate(input_toks_.size)); input_toks_.tensor = std::make_unique( @@ -61,7 +77,7 @@ void PromptProcessor::init_io( input_toks_.data, input_toks_.size, input_toks.get()); // [I]: attention_mask - Result attention_mask = method_meta->input_tensor_meta(1); + Result attention_mask = method_meta->input_tensor_meta(idx++); attention_mask_.data = reinterpret_cast( buffer_manager->allocate(attention_mask_.size)); attention_mask_.tensor = std::make_unique( @@ -75,9 +91,30 @@ void PromptProcessor::init_io( buffer_manager->add_memory_info( attention_mask_.data, 
attention_mask_.size, attention_mask.get()); + // [I]: sliding window attention_mask + if (metadata_.cache_mode == CacheMode::HybridCache) { + Result window_attention_mask = + method_meta->input_tensor_meta(idx++); + window_attention_mask_.data = reinterpret_cast( + buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.tensor = std::make_unique( + window_attention_mask->scalar_type(), + window_attention_mask->sizes().size(), + const_cast( + window_attention_mask->sizes().data()), + window_attention_mask_.data, + const_cast( + window_attention_mask->dim_order().data())); + input_tensors_.emplace_back(window_attention_mask_.tensor.get()); + buffer_manager->add_memory_info( + window_attention_mask_.data, + window_attention_mask_.size, + window_attention_mask.get()); + } + if (!is_bert()) { // [I]: input_pos - Result input_pos = method_meta->input_tensor_meta(2); + Result input_pos = method_meta->input_tensor_meta(idx++); input_pos_.data = reinterpret_cast(buffer_manager->allocate(input_pos_.size)); input_pos_.tensor = std::make_unique( @@ -91,7 +128,7 @@ void PromptProcessor::init_io( input_pos_.data, input_pos_.size, input_pos.get()); // [I] kv_cache - int index = 3; // bypass input_tokens, atten_mask, input_pos + size_t index = idx; // bypass input_tokens, atten_mask, input_pos for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); @@ -133,7 +170,7 @@ void PromptProcessor::init_io( buffer_manager->add_memory_info(logits_.data, logits_.size, logits.get()); // [O] kv_cache - int index = 1; + size_t index = 1; for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); @@ -234,6 +271,16 @@ Result PromptProcessor::prefill( // Initialize attention mask with current position kv_manager_->init_attention_mask( attention_mask_.data, attention_map, metadata_.ar_len, pos); + // Initialize window attention mask with current position + if (metadata_.cache_mode == CacheMode::HybridCache) { + kv_manager_->init_attention_mask( + window_attention_mask_.data, + attention_map, + metadata_.ar_len, + pos, + metadata_.sliding_window); + } + // Initialize the output of the module ET_CHECK_MSG( decoder_runner_->set_outputs(method_name_, output_tensors_) == @@ -275,9 +322,18 @@ Result PromptProcessor::prefill( } // Update KV Cache with the output results kv_manager_->update_cache(metadata_.ar_len, pos, n_update, {}); + // Update attention mask with current position kv_manager_->update_attention_mask( attention_mask_.data, metadata_.ar_len, pos, n_update); + if (metadata_.cache_mode == CacheMode::HybridCache) { + kv_manager_->update_attention_mask( + window_attention_mask_.data, + metadata_.ar_len, + pos, + n_update, + metadata_.sliding_window); + } prompt_pos += metadata_.ar_len; pos += metadata_.ar_len; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h index 04945558ae5..a3dd2079461 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h @@ -7,6 +7,7 @@ */ #pragma once +#include #include #include #include @@ -29,6 +30,8 @@ class PromptProcessor { int32_t ar_len; int32_t vocab_size; bool use_int64_token; + int sliding_window; + CacheMode cache_mode; }; PromptProcessor( DecoderRunner* decoder_runner, @@ -72,8 +75,13 @@ class PromptProcessor { * @return Total I/O size 
in bytes. */ inline const size_t total_prompt_processor_io_size_in_bytes() const { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; + if (metadata_.cache_mode == CacheMode::HybridCache) { + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; + } else { + return input_toks_.size + input_pos_.size + attention_mask_.size + + logits_.size; + } } private: @@ -103,6 +111,7 @@ class PromptProcessor { TensorStruct input_toks_; TensorStruct input_pos_; TensorStruct attention_mask_; + TensorStruct window_attention_mask_; TensorStruct logits_; // layer -> head -> TensorImpl diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 26b8cbfa991..83ea0b88ad0 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -122,12 +122,13 @@ Runner::Runner( decoder_model_version_ = DecoderModelVersion::kLlama2; } else if (decoder_model_version == "llama3") { decoder_model_version_ = DecoderModelVersion::kLlama3; - } else if (decoder_model_version == "qwen2_5") { - decoder_model_version_ = DecoderModelVersion::kQwen2_5; - } else if (decoder_model_version == "qwen3") { - decoder_model_version_ = DecoderModelVersion::kQwen3; + } else if (decoder_model_version == "gemma3") { + decoder_model_version_ = DecoderModelVersion::kGemma3; + cache_mode_ = CacheMode::HybridCache; } else if (decoder_model_version == "phi_4_mini") { decoder_model_version_ = DecoderModelVersion::kPhi4; + } else if (decoder_model_version == "qwen2_5") { + decoder_model_version_ = DecoderModelVersion::kQwen2_5; } else if (decoder_model_version == "smollm2_135m") { decoder_model_version_ = DecoderModelVersion::kSmollm2_135m; } else { @@ -192,6 +193,8 @@ Error Runner::load() { eos_ids->insert(tokenizer_->encode("<|end|>", 0, 0).get()[0]); } else if (decoder_model_version_ == DecoderModelVersion::kQwen3) { eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]); + } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) { + eos_ids->insert(tokenizer_->encode("", 0, 0).get()[0]); } // Try avoid getMetadataHelper as it is time consuming. @@ -207,7 +210,6 @@ Error Runner::load() { ET_CHECK_OK_OR_RETURN_ERROR(decoder_runner_->load(method_names)); ET_LOG(Info, "Reading metadata from model"); - // retrieve any method meta, can be either prefill or kv int64_t num_layers = ET_UNWRAP(module_->get("get_n_layers")).toScalar().to(); @@ -246,6 +248,13 @@ Error Runner::load() { std::min(token_generator_ar_len, prompt_processor_ar_len); max_ar_len = std::max(token_generator_ar_len, prompt_processor_ar_len); + // Load the sliding window size if the model supports it. + // This is used to configure the attention mask for models with window + // attention + int32_t sliding_window = context_len_; + if (module_->method_names()->count("get_sliding_window") > 0) { + sliding_window = ET_UNWRAP(module_->get("get_sliding_window")).toInt(); + } kv_manager_ = std::make_unique>( kv_updater_, typename KVManager::Metadata{ @@ -266,8 +275,16 @@ Error Runner::load() { num_layers, prompt_processor_ar_len, vocab_size, - use_int64_token}); + use_int64_token, + sliding_window, + cache_mode_}); if (eval_mode_ == EvalMode::kLookaheadDecoding) { + // TODO: sliding window attention will be supported in future. 
+ if (sliding_window < context_len_) { + ET_CHECK_MSG( + false, + "Lookahead decoding (eval_mode == 2) is not yet supported for sliding window attention."); + } token_generator_ = std::make_unique>( tokenizer_.get(), decoder_runner_.get(), @@ -283,7 +300,8 @@ Error Runner::load() { use_int64_token, ngram_, window_, - gcap_}, + gcap_, + sliding_window}, &stats_); } else { token_generator_ = std::make_unique>( @@ -298,7 +316,9 @@ Error Runner::load() { num_layers, token_generator_ar_len, vocab_size, - use_int64_token}, + use_int64_token, + sliding_window, + cache_mode_}, &stats_); } diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index a771a3c0108..cb6c08d9c87 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -31,10 +32,11 @@ namespace example { enum DecoderModelVersion { kLlama2 = 0, kLlama3, + kGemma3, + kPhi4, kQwen2_5, kQwen3, - kPhi4, - kSmollm2_135m + kSmollm2_135m, }; enum KvBitWidth { @@ -99,6 +101,10 @@ class Runner : public executorch::extension::llm::IRunner { int ngram_{0}; int window_{0}; int gcap_{0}; + + // Defaults to StaticCahce, indicating that the model does not use a + // global/local architecture. + CacheMode cache_mode_{CacheMode::StaticCahce}; int64_t cur_pos_{0}; std::string tokenizer_path_; @@ -106,6 +112,7 @@ class Runner : public executorch::extension::llm::IRunner { std::string dump_logits_path_; float temperature_; EvalMode eval_mode_; + DecoderModelVersion decoder_model_version_; KVManagerMode kv_updater_; std::unique_ptr buffer_manager_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index b04d3e4486d..6775c08bd87 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -40,17 +40,35 @@ TokenGenerator::TokenGenerator( input_pos_.size = metadata_.ar_len * sizeof(int32_t); attention_mask_.size = metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + + switch (metadata_.cache_mode) { + case CacheMode::StaticCahce: + attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + window_attention_mask_.size = 0; + break; + case CacheMode::HybridCache: + attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + window_attention_mask_.size = + metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + break; + default: + ET_CHECK_MSG(false, "Unsupported llama cache mode"); + break; + } + logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); } - template void TokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { + size_t idx = 0; input_tensors_.reserve(method_meta->num_inputs()); output_tensors_.reserve(method_meta->num_outputs()); // [I]: input_tokens - Result input_toks = method_meta->input_tensor_meta(0); + Result input_toks = method_meta->input_tensor_meta(idx++); input_toks_.data = reinterpret_cast(buffer_manager->allocate(input_toks_.size)); input_toks_.tensor = std::make_unique( @@ -64,7 +82,7 @@ void TokenGenerator::init_io( input_toks_.data, input_toks_.size, input_toks.get()); // [I]: attention_mask - Result attention_mask = method_meta->input_tensor_meta(1); + Result attention_mask = method_meta->input_tensor_meta(idx++); attention_mask_.data = 
reinterpret_cast( buffer_manager->allocate(attention_mask_.size)); attention_mask_.tensor = std::make_unique( @@ -78,8 +96,29 @@ void TokenGenerator::init_io( buffer_manager->add_memory_info( attention_mask_.data, attention_mask_.size, attention_mask.get()); + // [I]: sliding window attention_mask + if (metadata_.cache_mode == CacheMode::HybridCache) { + Result window_attention_mask = + method_meta->input_tensor_meta(idx++); + window_attention_mask_.data = reinterpret_cast( + buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.tensor = std::make_unique( + window_attention_mask->scalar_type(), + window_attention_mask->sizes().size(), + const_cast( + window_attention_mask->sizes().data()), + window_attention_mask_.data, + const_cast( + window_attention_mask->dim_order().data())); + input_tensors_.emplace_back(window_attention_mask_.tensor.get()); + buffer_manager->add_memory_info( + window_attention_mask_.data, + window_attention_mask_.size, + window_attention_mask.get()); + } + // [I]: input_pos - Result input_pos = method_meta->input_tensor_meta(2); + Result input_pos = method_meta->input_tensor_meta(idx++); input_pos_.data = reinterpret_cast(buffer_manager->allocate(input_pos_.size)); input_pos_.tensor = std::make_unique( @@ -93,7 +132,7 @@ void TokenGenerator::init_io( input_pos_.data, input_pos_.size, input_pos.get()); // [I] kv_cache - int index = 3; // bypass input_tokens, atten_mask, input_pos + size_t index = idx; // bypass input_tokens, atten_mask, input_pos for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); @@ -198,9 +237,20 @@ Result TokenGenerator::generate( kv_manager_->rearrange_cache(metadata_.ar_len); std::vector attention_map(metadata_.ar_len); std::iota(attention_map.begin(), attention_map.end(), -1); + // Initialize attention mask with current position kv_manager_->init_attention_mask( attention_mask_.data, attention_map, metadata_.ar_len, pos); + // Initialize window attention mask with current position + if (metadata_.cache_mode == CacheMode::HybridCache) { + kv_manager_->init_attention_mask( + window_attention_mask_.data, + attention_map, + metadata_.ar_len, + pos, + metadata_.sliding_window); + } + // Initialize the output of the module ET_CHECK_MSG( decoder_runner_->set_outputs(method_name_, output_tensors_) == @@ -252,6 +302,14 @@ Result TokenGenerator::generate( // Update attention mask with current position kv_manager_->update_attention_mask( attention_mask_.data, metadata_.ar_len, pos, metadata_.ar_len); + if (metadata_.cache_mode == CacheMode::HybridCache) { + kv_manager_->update_attention_mask( + window_attention_mask_.data, + metadata_.ar_len, + pos, + metadata_.ar_len, + metadata_.sliding_window); + } pos++; // print the token as string, decode it with the Tokenizer object @@ -267,7 +325,6 @@ Result TokenGenerator::generate( } return pos - start_pos; } - // Explicit instantiations template class TokenGenerator; template class TokenGenerator; diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index 682c1531b88..9f0198f3040 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -7,6 +7,7 @@ */ #pragma once +#include #include #include #include @@ -30,6 +31,8 @@ class TokenGenerator { int32_t ar_len; int32_t vocab_size; bool use_int64_token; + int sliding_window; + CacheMode 
cache_mode; }; TokenGenerator( tokenizers::Tokenizer* tokenizer, @@ -73,8 +76,13 @@ class TokenGenerator { std::function token_callback, bool dump_logits); inline const size_t total_token_generator_io_size_in_bytes() const { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; + if (metadata_.cache_mode == CacheMode::HybridCache) { + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; + } else { + return input_toks_.size + input_pos_.size + attention_mask_.size + + logits_.size; + } } protected: @@ -88,6 +96,7 @@ class TokenGenerator { TensorStruct input_toks_; TensorStruct input_pos_; TensorStruct attention_mask_; + TensorStruct window_attention_mask_; TensorStruct logits_; // layer -> head -> TensorImpl
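
The hybrid-cache path above feeds the decoder two masks: the usual global causal mask and a second, window-limited mask whose rows only expose the last sliding_window positions. The following standalone sketch is illustrative only; build_local_mask_row is a hypothetical helper, not part of KVManager, and it simply shows how one SMART_MASK-style row could look for a token at a given absolute position, using the same 0 / 65535 convention as the runner.

// Minimal sketch (assumed semantics): build one attention-mask row of width
// context_len for a token at absolute position `pos`, limited to a local
// window of `sliding_window` tokens. 0 = masked, 65535 = attend.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<uint16_t> build_local_mask_row(
    int32_t pos, int32_t context_len, int32_t sliding_window) {
  const uint16_t neg_val = 0, pos_val = 65535;
  std::vector<uint16_t> row(context_len, neg_val);
  // A token at `pos` may attend to itself and to the previous
  // (sliding_window - 1) positions; everything earlier stays masked.
  const int32_t first = std::max<int32_t>(0, pos - sliding_window + 1);
  std::fill(row.begin() + first, row.begin() + pos + 1, pos_val);
  return row;
}

int main() {
  // Example: context_len = 8, sliding_window = 4, token at position 5
  // -> positions 2..5 are attended, everything else is masked.
  for (uint16_t v :
       build_local_mask_row(/*pos=*/5, /*context_len=*/8, /*sliding_window=*/4)) {
    std::cout << (v ? 1 : 0) << ' ';
  }
  std::cout << '\n';  // prints: 0 0 1 1 1 1 0 0
  return 0;
}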
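
Likewise, a minimal sketch of the per-step bookkeeping the sliding-window update performs during decoding, assuming ar_len = 1 and one new token per step: the global row only ever gains attended positions, while the windowed row also re-masks the position that just fell out of the window. Again, this illustrates the idea only and is not the KVManager implementation.

// Standalone sketch (assumed semantics): maintain a global row and a local
// (windowed) row while decoding six tokens, one token per step.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t context_len = 8, sliding_window = 3;
  const uint16_t neg_val = 0, pos_val = 65535;
  std::vector<uint16_t> global_row(context_len, neg_val);
  std::vector<uint16_t> local_row(context_len, neg_val);

  for (int32_t pos = 0; pos < 6; ++pos) {
    // Global attention: unmask every position up to and including `pos`.
    global_row[pos] = pos_val;
    // Local attention: unmask `pos` ...
    local_row[pos] = pos_val;
    // ... and re-mask the position that just left the sliding window, if any.
    if (pos - sliding_window >= 0) {
      local_row[pos - sliding_window] = neg_val;
    }
  }
  // After 6 steps the global row attends to positions 0..5, while the local
  // row only attends to the last `sliding_window` positions 3..5.
  for (uint16_t v : global_row) std::cout << (v ? 1 : 0) << ' ';
  std::cout << '\n';  // 1 1 1 1 1 1 0 0
  for (uint16_t v : local_row) std::cout << (v ? 1 : 0) << ' ';
  std::cout << '\n';  // 0 0 0 1 1 1 0 0
  return 0;
}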