diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Constants.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Constants.swift
index 1c2a9d12b97..4bc28f996ed 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Constants.swift
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/Constants.swift
@@ -26,5 +26,14 @@ You are a helpful assistant.
 
   public static let llama3PromptTemplate = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
 
-public static let phi4PromptTemplate = "<|user|>%@<|end|><|assistant|>"
+  public static let phi4PromptTemplate = "<|user|>%@<|end|><|assistant|>"
+
+  public static let gemma3PromptTemplate = """
+<start_of_turn>user
+
+
+%@<end_of_turn>
+<start_of_turn>model
+
+"""
 }
diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
index c6b8b71dfc1..6f8b8bb2c79 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift
@@ -87,6 +87,7 @@ struct ContentView: View {
     case llava
     case qwen3
     case phi4
+    case gemma3
 
     static func fromPath(_ path: String) -> ModelType {
       let filename = (path as NSString).lastPathComponent.lowercased()
@@ -98,7 +99,9 @@
         return .qwen3
       } else if filename.hasPrefix("phi4") {
         return .phi4
-      }
+      } else if filename.hasPrefix("gemma3") {
+        return .gemma3
+      }
       print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, qwen3, or phi4")
       exit(1)
     }
@@ -346,7 +349,7 @@
     }
 
     switch modelType {
-    case .llama, .qwen3, .phi4:
+    case .llama, .qwen3, .phi4, .gemma3:
       runnerHolder.llamaRunner = runnerHolder.llamaRunner ?? LLaMARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
     case .llava:
       runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
@@ -354,7 +357,7 @@
 
       guard !shouldStopGenerating else { return }
       switch modelType {
-      case .llama, .qwen3, .phi4:
+      case .llama, .qwen3, .phi4, .gemma3:
         if let runner = runnerHolder.llamaRunner, !runner.isLoaded() {
           var error: Error?
           let startLoadTime = Date()
@@ -479,6 +482,8 @@
             prompt = String(format: Constants.llama3PromptTemplate, text)
           case .phi4:
             prompt = String(format: Constants.phi4PromptTemplate, text)
+          case .gemma3:
+            prompt = String(format: Constants.gemma3PromptTemplate, text)
           }
 
           try runnerHolder.llamaRunner?.generate(prompt, sequenceLength: seq_len) { token in
diff --git a/examples/models/gemma3/__init__.py b/examples/models/gemma3/__init__.py
new file mode 100644
index 00000000000..ae34db47954
--- /dev/null
+++ b/examples/models/gemma3/__init__.py
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
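+
+# Gemma 3 1B reuses the shared Llama transformer implementation; Gemma3Model is
+# a thin wrapper around Llama2Model so the export flow can treat it uniformly.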
+
+from executorch.examples.models.gemma3.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class Gemma3Model(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "Gemma3Model",
+    "convert_weights",
+]
diff --git a/examples/models/gemma3/config/1b_config.json b/examples/models/gemma3/config/1b_config.json
new file mode 100644
index 00000000000..3a9e673716b
--- /dev/null
+++ b/examples/models/gemma3/config/1b_config.json
@@ -0,0 +1,23 @@
+{
+  "dim": 1152,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 6912,
+  "n_heads": 4,
+  "head_dim": 256,
+  "n_kv_heads": 1,
+  "n_layers": 26,
+  "act_fn": "gelu_approx",
+  "norm_type": "gemma3",
+  "norm_eps": 1e-06,
+  "post_attention_norm": true,
+  "post_ffn_norm": true,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "apply_embedding": true,
+  "embedding_scale_factor": 33.941125497,
+  "vocab_size": 262144,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false,
+  "use_qk_norm": true,
+  "qk_norm_before_rope": true
+}
diff --git a/examples/models/gemma3/config/gemma3_xnnpack_q8da4w.yaml b/examples/models/gemma3/config/gemma3_xnnpack_q8da4w.yaml
new file mode 100644
index 00000000000..cfd212f6239
--- /dev/null
+++ b/examples/models/gemma3/config/gemma3_xnnpack_q8da4w.yaml
@@ -0,0 +1,17 @@
+base:
+  model_class: gemma3_1b
+  metadata: '{"get_bos_id":[2, 105], "get_eos_ids":[1, 106]}'
+
+model:
+  use_kv_cache: True
+  use_sdpa_with_kv_cache: True
+  dtype_override: fp32
+  local_global_attention: [512,512,512,512,512,0,512,512,512,512,512,0,512,512,512,512,512,0,512,512,512,512,512,0,512,512]
+
+quantization:
+  qmode: 8da4w
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
\ No newline at end of file
diff --git a/examples/models/gemma3/convert_weights.py b/examples/models/gemma3/convert_weights.py
new file mode 100644
index 00000000000..ed44b9eb1cc
--- /dev/null
+++ b/examples/models/gemma3/convert_weights.py
@@ -0,0 +1,110 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters.
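+# Gemma 3 1B ties its LM head to the token embedding, which is why
+# gemma3_to_executorch() below also copies tok_embeddings.weight into output.weight.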
+_GEMMA3_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
+    "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def gemma3_to_executorch(
+    state_dict: Dict[str, torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GEMMA3_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = gemma3_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma3 weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 63d783c3332..fb974de31d2 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -6,7 +6,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import Norm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -324,7 +324,14 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        q_norm_fn: Optional[Norm] = None,
+        k_norm_fn: Optional[Norm] = None,
+    ):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -343,11 +350,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.enable_dynamic_shape = args.enable_dynamic_shape
 
-        if self.use_qk_norm:
-            q_norm_dim = self.head_dim
-            k_norm_dim = self.head_dim
-            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
-            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+        self.q_norm_fn = q_norm_fn
+        self.k_norm_fn = k_norm_fn
 
         self.wq = nn.Linear(
             self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 43ae595f797..79134079483 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -107,6 +107,7 @@
     "qwen3_0_6b",
     "qwen3_1_7b",
     "qwen3_4b",
+    "gemma3_1b",
     "phi_4_mini",
     "smollm2",
 ]
@@ -118,6 +119,7 @@
     "qwen3_0_6b": "Qwen/Qwen3-0.6B",
     "qwen3_1_7b": "Qwen/Qwen3-1.7B",
     "qwen3_4b": "Qwen/Qwen3-4B",
+    "gemma3_1b": "google/gemma-3-1b-it",
 }
 
 
@@ -609,6 +611,10 @@ def export_llama(
         from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
             convert_weights,
         )
+    elif model_name.startswith("gemma3"):
+        from executorch.examples.models.gemma3 import (  # pyre-ignore[21]
+            convert_weights,
+        )
     else:
         raise ValueError(
             f"Converting weights to meta format for {model_name} is not yet supported"
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 1fdcdcd91fc..fa2ed7204ca 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -7,7 +7,7 @@
 # Please refer to README.md in the same folder for more information.
 
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn.functional as F
@@ -19,7 +19,7 @@
 )
 
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import Norm, NORM_REGISTRY
 from executorch.examples.models.llama.rope import Rope
 
 from torch import nn
@@ -29,12 +29,13 @@ def __init__(self, args: ModelArgs):
         super().__init__()
         assert args.hidden_dim is not None
         hidden_dim: int = args.hidden_dim
+        self.act_fn = args.act_fn.get_function()  # Store the actual function
         self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
         self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
         self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
 
     def forward(self, x):
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))
 
 
 class ConditionalFeedForward(nn.Module):
@@ -84,7 +85,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(self, args: ModelArgs, attention: Attention, norm_cls: Type[Norm]):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
@@ -103,8 +104,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = norm_cls(args.dim, eps=args.norm_eps)
+        if args.post_attention_norm:
+            self.post_attention_norm = norm_cls(args.dim, eps=args.norm_eps)
+        self.ffn_norm = norm_cls(args.dim, eps=args.norm_eps)
+        if args.post_ffn_norm:
+            self.post_ffn_norm = norm_cls(args.dim, eps=args.norm_eps)
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -120,20 +125,50 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
                 f"Unknown attention type: {args.attention_type}. "
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
+        if args.norm_type not in NORM_REGISTRY:
+            raise ValueError(
+                f"Unknown norm type: {args.norm_type}. "
+                f"Available: {list(NORM_REGISTRY.keys())}"
+            )
+        norm_cls = NORM_REGISTRY[args.norm_type]
+
+        # Create qk_norm instances if needed
+        q_norm_fn = None
+        k_norm_fn = None
+        if args.attention_type == "static":
+            q_norm_fn = torch.nn.Identity()
+            k_norm_fn = torch.nn.Identity()
+        if args.use_qk_norm:
+            q_norm_fn = norm_cls(args.head_dim, eps=args.norm_eps)
+            k_norm_fn = norm_cls(args.head_dim, eps=args.norm_eps)
+
         cls = ATTENTION_REGISTRY[args.attention_type]
-        attention = cls(args, layer_id, rope)
-        return TransformerBlock(args, attention)
+        attention = cls(args, layer_id, rope, q_norm_fn, k_norm_fn)
+
+        return TransformerBlock(args, attention, norm_cls)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
-        h, attn_options_update = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
+        # Attention.
+        residual = x
+        x_norm = self.attention_norm(x)
+
+        hidden, attn_options_update = self.attention.forward(
+            x_norm, freqs_cos, freqs_sin, **attn_options
         )
+        if hasattr(self, "post_attention_norm"):
+            hidden = self.post_attention_norm(hidden)
+        hidden = residual + hidden
 
-        h = x + h
+        # MLP.
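+        # Pre-norm, then FFN (or MoE), then an optional Gemma-style post-norm before the residual add.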
+        residual = hidden
         if hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            hidden = self.block_sparse_moe(self.ffn_norm(hidden))
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            hidden = self.feed_forward(self.ffn_norm(hidden))
+        if hasattr(self, "post_ffn_norm"):
+            hidden = self.post_ffn_norm(hidden)
+        out = residual + hidden
+
         return out, attn_options_update
 
 
@@ -152,6 +187,7 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
         self.apply_embedding = params.apply_embedding
+        self.embedding_scale_factor = params.embedding_scale_factor
         self.apply_output = params.apply_output
 
         self.tok_embeddings = (
@@ -161,7 +197,13 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         )
         self.layers = layers
         self.rope = rope
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        if params.norm_type not in NORM_REGISTRY:
+            raise ValueError(
+                f"Unknown norm type: {params.norm_type}. "
+                f"Available: {list(NORM_REGISTRY.keys())}"
+            )
+        norm_cls = NORM_REGISTRY[params.norm_type]
+        self.norm = norm_cls(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
             if self.apply_output
@@ -180,12 +222,13 @@ def forward(
         attn_options: Optional[ForwardOptions] = None,
         h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
+
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
 
         if self.apply_embedding and tokens is not None and h is None:
-            h = self.tok_embeddings(tokens)
+            h = self.embedding_scale_factor * self.tok_embeddings(tokens)
 
         if attn_options is None:
             attn_options = {}
@@ -251,11 +294,29 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             f"Unknown attention type: {model_args.attention_type}. "
             f"Available: {list(ATTENTION_REGISTRY.keys())}"
         )
+    if model_args.norm_type not in NORM_REGISTRY:
+        raise ValueError(
+            f"Unknown norm type: {model_args.norm_type}. "
" + f"Available: {list(NORM_REGISTRY.keys())}" + ) + norm_cls = NORM_REGISTRY[model_args.norm_type] + layers = torch.nn.ModuleList() cls = ATTENTION_REGISTRY[model_args.attention_type] for layer_id in range(model_args.n_layers): - attention = cls(model_args, layer_id, rope) - transformer_block = TransformerBlock(model_args, attention) + # Create qk_norm instances if needed + q_norm_fn = None + k_norm_fn = None + if model_args.use_qk_norm: + q_norm_fn = norm_cls(model_args.head_dim, eps=model_args.norm_eps) + k_norm_fn = norm_cls(model_args.head_dim, eps=model_args.norm_eps) + elif model_args.attention_type == "static": + # StaticAttention expects Identity functions when qk_norm is disabled + q_norm_fn = torch.nn.Identity() + k_norm_fn = torch.nn.Identity() + + attention = cls(model_args, layer_id, rope, q_norm_fn, k_norm_fn) + transformer_block = TransformerBlock(model_args, attention, norm_cls) layers.append(transformer_block) return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 5734cd66ef7..20ce454287b 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -1,6 +1,38 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum +from functools import partial from typing import Dict, Optional +import torch.nn.functional as F + + +class ActFn(Enum): + SILU = "silu" + GELU = "gelu" + GELU_APPROX = "gelu_approx" + + @classmethod + def from_string(cls, value: str) -> "ActFn": + """Convert string to ActFn enum.""" + try: + return cls(value) + except ValueError: + valid_values = [e.value for e in cls] + raise ValueError( + f"Invalid activation function: {value}. Valid options: {valid_values}" + ) + + def get_function(self): + """Return the corresponding activation function.""" + if self == ActFn.SILU: + return F.silu + elif self == ActFn.GELU: + return F.gelu + elif self == ActFn.GELU_APPROX: + return partial(F.gelu, approximate="tanh") + else: + raise ValueError(f"Unsupported activation function: {self}") + @dataclass class ModelArgs: @@ -14,6 +46,8 @@ class ModelArgs: multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 + post_attention_norm: bool = False + post_ffn_norm: bool = False max_batch_size: int = 1 max_seq_len: int = 2048 max_context_len: int = 2048 @@ -21,6 +55,8 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + norm_type: str = "rmsnorm" # Normalization type, registered in norm.py + act_fn: ActFn = field(default=ActFn.SILU) # Activation function type attention_qkv_bias: bool = False use_kv_cache: bool = False # Use key/value cache use_sdpa_with_kv_cache_op: bool = ( @@ -36,6 +72,7 @@ class ModelArgs: # A dictionary mapping from pruned token-id to original token-id output_prune_map: Optional[Dict[int, int]] = None apply_embedding: bool = True # Use embedding inside the transformer + embedding_scale_factor: float = 1.0 # Multiple by which to scale embeddings. 
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
@@ -86,3 +123,7 @@ def find_multiple(n: int, k: int) -> int:
 
         if self.head_dim is None:
             self.head_dim = self.dim // self.n_heads
+
+        # Convert string act_fn to enum if needed
+        if isinstance(self.act_fn, str):
+            self.act_fn = ActFn.from_string(self.act_fn)
diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
index 3786e61cd05..fade8381529 100644
--- a/examples/models/llama/norm.py
+++ b/examples/models/llama/norm.py
@@ -4,11 +4,44 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from abc import ABC, abstractmethod
+from typing import Dict, Type
+
 import torch
 from torch import nn
 
 
-class RMSNorm(torch.nn.Module):
+class Norm(nn.Module, ABC):
+    """Abstract base class for normalization layers with unified interface."""
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass for normalization layer.
+
+        Args:
+            x: Input tensor
+
+        Returns:
+            Normalized tensor
+        """
+        pass
+
+
+NORM_REGISTRY: Dict[str, Type[Norm]] = {}
+
+
+def register_norm(name: str):
+    """Decorator to register norm classes"""
+
+    def decorator(cls: Type[Norm]):
+        NORM_REGISTRY[name.lower()] = cls
+        return cls
+
+    return decorator
+
+
+@register_norm("rmsnorm")
+class RMSNorm(Norm):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
         Initialize the RMSNorm normalization layer.
@@ -53,3 +86,24 @@ def forward(self, x):
         """
         output = self._norm(x.float()).type_as(x)
         return output * self.weight
+
+
+@register_norm("gemma3")
+class Gemma3RMSNorm(Norm):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 7e662317509..82440a34813 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -28,6 +28,7 @@ def __init__(
         self,
         llm_config: LlmConfig,
         tokenizer_config_path: Optional[str] = None,
+        tokenizer_type: Optional[str] = None,
         use_attention_sink: bool = False,
     ):
         with open(llm_config.base.params, "r") as f:
@@ -35,6 +36,7 @@ def __init__(
         super().__init__(
             tokenizer_path=llm_config.base.tokenizer_path,
             tokenizer_config_path=tokenizer_config_path,
+            tokenizer_type=tokenizer_type,
             max_seq_len=llm_config.export.max_seq_length,
             max_batch_size=1,
             use_kv_cache=llm_config.model.use_kv_cache,
@@ -89,6 +91,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
     )
 
+    parser.add_argument(
+        "--tokenizer_type",
+        type=str,
+        choices=["sentencepiece", "huggingface", "llama2c", "tiktoken"],
+        help="Type of tokenizer",
+    )
+
     return parser
 
 
@@ -105,6 +114,7 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     show_tokens = args.show_tokens
     chat_mode = args.chat
     tokenizer_config_path = args.tokenizer_config_path
+    tokenizer_type = args.tokenizer_type
     use_attention_sink = args.use_attention_sink
 
     with torch.no_grad():
@@ -112,6 +122,7 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
         runner = runner_class(
             llm_config=llm_config,
             tokenizer_config_path=tokenizer_config_path,
+            tokenizer_type=tokenizer_type,
             use_attention_sink=use_attention_sink,
         )
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 2baa8f5cd14..628addc9a8d 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -10,7 +10,7 @@
 
 import torch
 
-from pytorch_tokenizers import get_tokenizer
+from pytorch_tokenizers import get_tokenizer, TokenizerType
 
 
 def sample_top_p(probs, p):
@@ -52,6 +52,7 @@ def __init__(
         *,
         tokenizer_path: str,
         tokenizer_config_path: Optional[str] = None,
+        tokenizer_type: Optional[str] = None,
         max_seq_len: int,
         max_batch_size: int,
         use_kv_cache: bool,
@@ -72,7 +73,11 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
-        self.tokenizer = get_tokenizer(tokenizer_path, tokenizer_config_path)
+        self.tokenizer = get_tokenizer(
+            tokenizer_path,
+            tokenizer_config_path,
+            TokenizerType.from_str(tokenizer_type),
+        )
         self.device = device
         # For some models like qwen, mismatch is acceptable: https://github.com/QwenLM/Qwen2.5/issues/466#issuecomment-2146759706
         if vocab_size != self.tokenizer.n_words:
@@ -172,8 +177,10 @@ def text_completion(
         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
+        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
+        print(prompt_tokens)
         return self.generate(
-            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            prompt_tokens=prompt_tokens,
             max_seq_len=self.max_seq_len,
             temperature=temperature,
             top_p=top_p,
diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py
index 6d5d4730844..0eb89795493 100644
--- a/examples/models/llama/runner/native.py
+++ b/examples/models/llama/runner/native.py
@@ -38,6 +38,7 @@ def __init__(self, args):
         super().__init__(
             tokenizer_path=args.tokenizer,
             tokenizer_config_path=args.tokenizer_config,
+            tokenizer_type=args.tokenizer_type,
             max_seq_len=args.max_len,
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
@@ -101,6 +102,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Path to an accompanying tokenizer_config.json, which provides metadata for the main tokenizer.json",
     )
 
+    parser.add_argument(
+        "--tokenizer_type",
+        type=str,
+        choices=["sentencepiece", "huggingface", "llama2c", "tiktoken"],
+        help="Type of tokenizer",
+    )
+
     parser.add_argument(
         "--prompt",
         type=str,
diff --git a/examples/models/llama/source_transformation/rms_norm.py b/examples/models/llama/source_transformation/rms_norm.py
index 3d94f73b631..b9e27562df7 100644
--- a/examples/models/llama/source_transformation/rms_norm.py
+++ b/examples/models/llama/source_transformation/rms_norm.py
@@ -5,12 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from executorch.examples.models.llama.llama_transformer import RMSNorm
+from executorch.examples.models.llama.norm import Gemma3RMSNorm, RMSNorm
+from torch import nn
 
 
 def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
+    """Replace custom norm implementations with torch.nn.RMSNorm.
+
+    Handles both standard RMSNorm and Gemma3RMSNorm with appropriate
+    weight scaling conversions.
+    """
     for name, child in module.named_children():
         if isinstance(child, RMSNorm):
+            # Standard RMSNorm: direct replacement
             rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
             rms_norm.weight = child.weight
             setattr(
@@ -18,6 +25,15 @@ def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
                 name,
                 rms_norm,
             )
+        elif isinstance(child, Gemma3RMSNorm):
+            # Gemma3RMSNorm: convert weight scaling from (1.0 + w) to w
+            rms_norm = torch.nn.RMSNorm(child.weight.shape[0], eps=child.eps)
+            rms_norm.weight = nn.Parameter(1.0 + child.weight)
+            setattr(
+                module,
+                name,
+                rms_norm,
+            )
         else:
             replace_rms_norm_with_native_rms_norm(child)
     return module
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 57b5796cbb3..e18b6d3bda1 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -11,6 +11,7 @@
     register_attention,
 )
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.norm import Norm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -365,7 +366,14 @@ class StaticAttention(Attention):
     model only needs to perform a concat to combine past and new data.
     """
 
-    def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        config: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        q_norm_fn: Optional[Norm] = None,
+        k_norm_fn: Optional[Norm] = None,
+    ):
         super().__init__()
         self.n_heads = config.n_heads
         self.n_kv_heads = (
@@ -410,8 +418,16 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.rope = _Rope(rope.params.use_hf_rope)
 
         if self.use_qk_norm:
-            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
-            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.q_norm = (
+                q_norm_fn
+                if q_norm_fn is not None
+                else torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            )
+            self.k_norm = (
+                k_norm_fn
+                if k_norm_fn is not None
+                else torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            )
         else:
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
@@ -512,10 +528,16 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
         if other.use_qk_norm:
             self.use_qk_norm = True
             self.qk_norm_before_rope = other.qk_norm_before_rope
-            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
-            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
-            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
-            self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
+            if other.q_norm_fn is not None:
+                self.q_norm = torch.nn.RMSNorm(
+                    other.q_norm_fn.weight.shape[0], other.q_norm_fn.eps
+                )
+                self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
+            if other.k_norm_fn is not None:
+                self.k_norm = torch.nn.RMSNorm(
+                    other.k_norm_fn.weight.shape[0], other.k_norm_fn.eps
+                )
+                self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
 
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index 94bbb2d8b2e..c025b4e191e 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -41,6 +41,7 @@ class ModelType(str, Enum):
     qwen3_0_6b = "qwen3_0_6b"
     qwen3_1_7b = "qwen3_1_7b"
     qwen3_4b = "qwen3_4b"
+    gemma3_1b = "gemma3_1b"
     phi_4_mini = "phi_4_mini"
     smollm2 = "smollm2"
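
A minimal usage sketch of the pieces added above (assumptions: the ExecuTorch source tree with these changes is importable, torch >= 2.4 for torch.nn.RMSNorm, and "gemma-3-1b-it" / "gemma3_1b_checkpoint.pth" are placeholder paths for a locally downloaded Hugging Face checkpoint directory and the converted output):

    import torch

    from executorch.examples.models.gemma3 import convert_weights
    from executorch.examples.models.llama.norm import NORM_REGISTRY

    # Convert a downloaded checkpoint (safetensors shards or pytorch_model.bin)
    # into the layout the shared llama transformer expects. Paths are placeholders.
    convert_weights("gemma-3-1b-it", "gemma3_1b_checkpoint.pth")

    # The "gemma3" norm stores its gain as (1 + weight), so a freshly initialized
    # instance (weight == 0) matches a plain RMSNorm with unit gain.
    gemma_norm = NORM_REGISTRY["gemma3"](1152, eps=1e-6)
    reference = torch.nn.RMSNorm(1152, eps=1e-6)
    x = torch.randn(2, 1152)
    torch.testing.assert_close(gemma_norm(x), reference(x))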