From f4aadcc8146a41a18c7207986bff137bff2681ab Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 3 Sep 2025 12:42:10 -0700
Subject: [PATCH 1/2] model : support LiquidAI LFM2 hybrid family (#13805)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Add support for the [LiquidAI LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) model family. For more information about the models, please read [the blog post](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models).

- Support the hybrid LFM2-350M, LFM2-700M, and LFM2-1.2B models.
- Add `ShortConvBlock`.
- Modify `construct_transformer` to construct hybrid architectures.
- Move `FeedForward` to avoid a cyclic dependency.

Instructions are in `examples/models/lfm2/README.md`.

Pull Request resolved: https://github.com/pytorch/executorch/pull/13805

Test Plan: All commands in `README.md` are tested.

```
❯ python -m examples.models.llama.runner.native \
  --model lfm2_700m \
  --pte lfm2_700m_8da4w.pte \
  --tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \
  --tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \
  --prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \
  --params examples/models/lfm2/config/lfm2_700m_config.json \
  --max_len 128 \
  -kv \
  --temperature 0.3
...
I'm an AI designed to assist with generating text based on the prompts you provide. I'm a type of language model, but I don't have a physical form or consciousness. I operate based on complex algorithms and vast amounts of training data. How can I help you today? If you have a specific question or need assistance with something, feel free to ask!
...
```

Differential Revision: D81593776

Pulled By: jackzhxng
---
 examples/models/lfm2/README.md                |  62 ++++++++++
 examples/models/lfm2/__init__.py              |   5 +
 .../models/lfm2/config/lfm2_1_2b_config.json  |  34 ++++++
 .../models/lfm2/config/lfm2_350m_config.json  |  34 ++++++
 .../models/lfm2/config/lfm2_700m_config.json  |  34 ++++++
 .../models/lfm2/config/lfm2_xnnpack_fp32.yaml |  12 ++
 .../lfm2/config/lfm2_xnnpack_q8da4w.yaml      |  15 +++
 examples/models/lfm2/convert_weights.py       |  74 ++++++++++++
 examples/models/lfm2/short_conv.py            | 110 ++++++++++++++++++
 examples/models/llama/export_llama_lib.py     |   8 ++
 examples/models/llama/feed_forward.py         |  14 +++
 examples/models/llama/llama_transformer.py    |  34 +++---
 examples/models/llama/model_args.py           |   2 +
 .../llama/source_transformation/spin_quant.py |   2 +-
 extension/llm/export/config/llm_config.py     |   3 +
 15 files changed, 424 insertions(+), 19 deletions(-)
 create mode 100644 examples/models/lfm2/README.md
 create mode 100644 examples/models/lfm2/__init__.py
 create mode 100644 examples/models/lfm2/config/lfm2_1_2b_config.json
 create mode 100644 examples/models/lfm2/config/lfm2_350m_config.json
 create mode 100644 examples/models/lfm2/config/lfm2_700m_config.json
 create mode 100644 examples/models/lfm2/config/lfm2_xnnpack_fp32.yaml
 create mode 100644 examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml
 create mode 100644 examples/models/lfm2/convert_weights.py
 create mode 100644 examples/models/lfm2/short_conv.py
 create mode 100644 examples/models/llama/feed_forward.py

diff --git a/examples/models/lfm2/README.md b/examples/models/lfm2/README.md
new file mode 100644
index 00000000000..8f2a84be376
--- /dev/null
+++ b/examples/models/lfm2/README.md
@@ -0,0 +1,62 @@
+## Summary
+[LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) is a new generation of hybrid models developed by [Liquid AI](https://www.liquid.ai/), available in three variants: 350M, 700M, and 1.2B.
+
+## Instructions
+
+LFM2 uses the same example code as the optimized Llama model, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.
+LFM2 is a hybrid model, where some attention layers are replaced with short convolutions.
+
+### Example export
+Here is a basic example of exporting LFM2; refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.
+ +Export 350m to XNNPack, quantized with 8da4w: +``` +python -m extension.llm.export.export_llm \ + --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \ + +base.model_class="lfm2_350m" \ + +base.params="examples/models/lfm2/config/lfm2_350m_config.json" \ + +export.output_name="lfm2_350m_8da4w.pte" +``` + +Export 700m to XNNPack, quantized with 8da4w: +``` +python -m extension.llm.export.export_llm \ + --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \ + +base.model_class="lfm2_700m" \ + +base.params="examples/models/lfm2/config/lfm2_700m_config.json" \ + +export.output_name="lfm2_700m_8da4w.pte" +``` + +Export 1_2b to XNNPack, quantized with 8da4w: +``` +python -m extension.llm.export.export_llm \ + --config examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml \ + +base.model_class="lfm2_1_2b" \ + +base.params="examples/models/lfm2/config/lfm2_1_2b_config.json" \ + +export.output_name="lfm2_1_2b_8da4w.pte" +``` +### Example run +With ExecuTorch pybindings: +``` +python -m examples.models.llama.runner.native \ + --model lfm2_700m \ + --pte lfm2_700m_8da4w.pte \ + --tokenizer ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \ + --tokenizer_config ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer_config.json \ + --prompt "<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \ + --params examples/models/lfm2/config/lfm2_700m_config.json \ + --max_len 128 \ + -kv \ + --temperature 0.3 +``` + +With ExecuTorch's sample c++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner): +``` +cmake-out/examples/models/llama/llama_main \ + --model_path lfm2_700m_8da4w.pte \ + --tokenizer_path ~/.cache/huggingface/hub/models--LiquidAI--LFM2-700M/snapshots/ab260293733f05dd4ce22399bea1cae2cf9b272d/tokenizer.json \ + --prompt="<|startoftext|><|im_start|>user\nWho are you?<|im_end|>\n<|im_start|>assistant\n" \ + --temperature 0.3 +``` + +To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section. 
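[Editor's note, not part of the patch] The hybrid layout described in the README above is driven entirely by the `layer_types` list in the param JSONs added below. As a reading aid, here is a minimal sketch of how `construct_transformer` interprets that list, mirroring the change to `examples/models/llama/llama_transformer.py` later in this diff; the config path in the usage line assumes you run from the executorch repo root after applying the patch.

```python
# Minimal sketch (not part of this patch): how the per-layer "layer_types"
# entries in an LFM2 param JSON select between the ShortConvBlock added by
# this diff and the existing attention TransformerBlock.
import json
from typing import List


def plan_layers(params_path: str) -> List[str]:
    with open(params_path) as f:
        params = json.load(f)
    layer_types = params.get("layer_types")  # absent for pure-attention models
    plan = []
    for layer_id in range(params["n_layers"]):
        if layer_types and layer_types[layer_id] == "conv":
            plan.append("ShortConvBlock")  # short depthwise-convolution block
        else:
            plan.append("TransformerBlock")  # regular attention block
    return plan


if __name__ == "__main__":
    # Example path from this patch; adjust if your checkout differs.
    print(plan_layers("examples/models/lfm2/config/lfm2_700m_config.json"))
```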
diff --git a/examples/models/lfm2/__init__.py b/examples/models/lfm2/__init__.py new file mode 100644 index 00000000000..224282df905 --- /dev/null +++ b/examples/models/lfm2/__init__.py @@ -0,0 +1,5 @@ +from executorch.examples.models.lfm2.convert_weights import convert_weights + +__all__ = [ + "convert_weights", +] diff --git a/examples/models/lfm2/config/lfm2_1_2b_config.json b/examples/models/lfm2/config/lfm2_1_2b_config.json new file mode 100644 index 00000000000..015b6940ed9 --- /dev/null +++ b/examples/models/lfm2/config/lfm2_1_2b_config.json @@ -0,0 +1,34 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 8192, + "n_heads": 32, + "n_kv_heads": 8, + "n_layers": 16, + "norm_eps": 1e-5, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 65536, + "use_hf_rope": true, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "layer_types": [ + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "conv" + ] +} diff --git a/examples/models/lfm2/config/lfm2_350m_config.json b/examples/models/lfm2/config/lfm2_350m_config.json new file mode 100644 index 00000000000..4cd2b2a1830 --- /dev/null +++ b/examples/models/lfm2/config/lfm2_350m_config.json @@ -0,0 +1,34 @@ +{ + "dim": 1024, + "ffn_dim_multiplier": 1, + "hidden_dim": 4608, + "n_heads": 16, + "n_kv_heads": 8, + "n_layers": 16, + "norm_eps": 1e-5, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 65536, + "use_hf_rope": true, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "layer_types": [ + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "conv" + ] +} diff --git a/examples/models/lfm2/config/lfm2_700m_config.json b/examples/models/lfm2/config/lfm2_700m_config.json new file mode 100644 index 00000000000..9f3afc0f121 --- /dev/null +++ b/examples/models/lfm2/config/lfm2_700m_config.json @@ -0,0 +1,34 @@ +{ + "dim": 1536, + "ffn_dim_multiplier": 1, + "hidden_dim": 6912, + "n_heads": 24, + "n_kv_heads": 8, + "n_layers": 16, + "norm_eps": 1e-5, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 65536, + "use_hf_rope": true, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "layer_types": [ + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "full_attention", + "conv", + "conv" + ] +} diff --git a/examples/models/lfm2/config/lfm2_xnnpack_fp32.yaml b/examples/models/lfm2/config/lfm2_xnnpack_fp32.yaml new file mode 100644 index 00000000000..9dd93821326 --- /dev/null +++ b/examples/models/lfm2/config/lfm2_xnnpack_fp32.yaml @@ -0,0 +1,12 @@ +base: + metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp32 + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml b/examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml new file mode 100644 index 00000000000..0925a5989be --- /dev/null +++ b/examples/models/lfm2/config/lfm2_xnnpack_q8da4w.yaml @@ -0,0 +1,15 @@ +base: + metadata: '{"get_bos_id": 1, "get_eos_ids":[7]}' + +model: + use_kv_cache: True + 
use_sdpa_with_kv_cache: True + dtype_override: fp32 + +quantization: + qmode: 8da4w + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git a/examples/models/lfm2/convert_weights.py b/examples/models/lfm2/convert_weights.py new file mode 100644 index 00000000000..d23f2aa7c49 --- /dev/null +++ b/examples/models/lfm2/convert_weights.py @@ -0,0 +1,74 @@ +import os +from typing import Dict + +import torch +from safetensors.torch import load_file + +from torchtune.models.convert_weights import get_mapped_key + +_LFM_2_TO_META = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.embedding_norm.weight": "norm.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight", + "model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight", +} + + +def lfm_2_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from LFM2 HF format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in LFM2 HF format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + + for key, value in state_dict.items(): + try: + new_key = get_mapped_key(key, _LFM_2_TO_META) + except: + new_key = key.removeprefix("model.") + + # split in_proj + if new_key.endswith(".conv.in_proj.weight"): + for name, split_value in zip( + ["B_proj", "C_proj", "x_proj"], torch.chunk(value, 3, dim=0) + ): + converted_state_dict[new_key.replace("in_proj", name)] = split_value + else: + converted_state_dict[new_key] = value + + # If lm_head.weight is not present in state dict, assume tied embeddings + if "lm_head.weight" not in state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + print("Loading checkpoint from safetensors directory") + state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + sd = load_checkpoint(input_dir) + print("Converting checkpoint...") + sd = lfm_2_to_meta(sd) + print("Saving checkpoint...") + torch.save(sd, output_file) + print("Done.") diff --git a/examples/models/lfm2/short_conv.py b/examples/models/lfm2/short_conv.py new file mode 100644 index 00000000000..5a141d4ce61 --- /dev/null +++ b/examples/models/lfm2/short_conv.py @@ -0,0 +1,110 @@ +import torch +from executorch.examples.models.llama.attention import ForwardOptions +from executorch.examples.models.llama.feed_forward import FeedForward + +from executorch.examples.models.llama.norm import RMSNorm +from torch import nn + + +class ShortConv(nn.Module): + def __init__( + self, + dim: int, + L_cache: int = 3, + bias: bool = False, + device: torch.device = None, + dtype: torch.dtype = None, + ): + 
super().__init__()
+        self.dim = dim
+        self.L_cache = L_cache
+        self.device = device
+        self.dtype = dtype
+        self.bias = bias
+
+        self.conv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=L_cache,
+            padding=0,  ## we don't need padding since we handle it manually
+            groups=dim,
+            bias=bias,
+        )
+
+        conv_state = torch.zeros(
+            1,  ## batch size is assumed to be 1 for now
+            dim,
+            L_cache - 1,
+            device="cpu",
+        )
+        self.register_buffer("conv_state", conv_state)
+
+        ## separate projections perform better in ExecuTorch
+        self.B_proj = nn.Linear(dim, dim, bias=bias)
+        self.C_proj = nn.Linear(dim, dim, bias=bias)
+        self.x_proj = nn.Linear(dim, dim, bias=bias)
+
+        self.out_proj = nn.Linear(dim, dim, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, seqlen, dim = x.size()
+        assert batch_size == 1, "batch_size must be 1"
+
+        B = self.B_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
+        C = self.C_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
+        x = self.x_proj(x).transpose(-1, -2)  # (batch_size, dim, seq_len)
+
+        Bx = B * x  # (batch_size, dim, seq_len)
+
+        ## This is where we handle padding.
+        ## By default, the conv_state is initialized to 0.
+        ## So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similarly to
+        ## using nn.Conv1d(padding=L_cache-1) (for prefill) with no manual padding.
+        ## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0.
+        Bx = torch.cat(
+            [self.conv_state, Bx], dim=-1
+        )  # (batch_size, dim, seq_len + L_cache - 1)
+
+        ## Update the conv_state
+        new_conv_state = Bx[
+            ..., -(self.L_cache - 1) :
+        ]  # (batch_size, dim, L_cache - 1)
+        with torch.no_grad():
+            self.conv_state.copy_(new_conv_state)
+
+        conv_out = self.conv(Bx)[..., : x.size(-1)]  # (batch_size, dim, seq_len)
+        y = C * conv_out  # (batch_size, dim, seq_len)
+
+        y = y.transpose(-1, -2)  # (batch_size, seq_len, dim)
+        y = y.contiguous()  # (batch_size, seq_len, dim)
+        y = self.out_proj(y)  # (batch_size, seq_len, dim)
+        return y
+
+    def reset_cache(self):
+        self.conv_state.zero_()
+
+
+class ShortConvBlock(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
+        super().__init__()
+        self.L_cache = 3  # hardcode 3 for now
+        self.conv = ShortConv(dim, self.L_cache, bias=False)
+        self.feed_forward = FeedForward(dim, hidden_dim)
+        self.ffn_norm = RMSNorm(dim, norm_eps)
+        # use the name attention_norm instead of operator_norm to unify with TransformerBlock
+        self.attention_norm = RMSNorm(dim, norm_eps)
+
+    def forward(
+        self,
+        x,
+        freqs_cos=None,
+        freqs_sin=None,
+        _unused_attn_options: ForwardOptions = None,
+    ):  # x: 1xN
+        h = self.conv.forward(self.attention_norm(x))
+        h = x + h
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out, None
+
+    def reset_cache(self):
+        self.conv.reset_cache()
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 07de9f3fa75..06dd57282ab 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -99,6 +99,9 @@
     "qwen3_4b",
     "phi_4_mini",
     "smollm2",
+    "lfm2_350m",  # hybrid
+    "lfm2_700m",  # hybrid
+    "lfm2_1_2b",  # hybrid
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
 HUGGING_FACE_REPO_IDS = {
@@ -108,6 +111,9 @@
     "qwen3_0_6b": "Qwen/Qwen3-0.6B",
     "qwen3_1_7b": "Qwen/Qwen3-1.7B",
     "qwen3_4b": "Qwen/Qwen3-4B",
+    "lfm2_350m": "LiquidAI/LFM2-350M",
+    "lfm2_700m": "LiquidAI/LFM2-700M",
+    "lfm2_1_2b":
"LiquidAI/LFM2-1.2B", } @@ -603,6 +609,8 @@ def export_llama( from executorch.examples.models.phi_4_mini import convert_weights elif model_name == "smollm2": from executorch.examples.models.smollm2 import convert_weights + elif model_name.startswith("lfm2"): + from executorch.examples.models.lfm2 import convert_weights else: raise ValueError( f"Converting weights to meta format for {model_name} is not yet supported" diff --git a/examples/models/llama/feed_forward.py b/examples/models/llama/feed_forward.py new file mode 100644 index 00000000000..3e7af7e0dc8 --- /dev/null +++ b/examples/models/llama/feed_forward.py @@ -0,0 +1,14 @@ +import torch.nn.functional as F +from torch import nn + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + assert hidden_dim is not None + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index 048f377ee7c..cdbb0c7557c 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -12,31 +12,19 @@ import torch import torch.nn.functional as F +from executorch.examples.models.lfm2.short_conv import ShortConvBlock from executorch.examples.models.llama.attention import ( Attention, ATTENTION_REGISTRY, ForwardOptions, ) - +from executorch.examples.models.llama.feed_forward import FeedForward from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.norm import RMSNorm from executorch.examples.models.llama.rope import Rope from torch import nn -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - assert args.hidden_dim is not None - hidden_dim: int = args.hidden_dim - self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) - self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - class ConditionalFeedForward(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -102,7 +90,7 @@ def __init__(self, args: ModelArgs, attention: Attention): if args.moe: self.block_sparse_moe = MOEFeedForward(args) else: - self.feed_forward = FeedForward(args) + self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) @@ -255,8 +243,18 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: layers = torch.nn.ModuleList() cls = ATTENTION_REGISTRY[model_args.attention_type] for layer_id in range(model_args.n_layers): - attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs) - transformer_block = TransformerBlock(model_args, attention) - layers.append(transformer_block) + # hybrid models define layer_types + if model_args.layer_types and model_args.layer_types[layer_id] == "conv": + layers.append( + ShortConvBlock( + dim=model_args.dim, + hidden_dim=model_args.hidden_dim, + norm_eps=model_args.norm_eps, + ) + ) + else: + attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs) + transformer_block = TransformerBlock(model_args, attention) + layers.append(transformer_block) return Transformer(model_args, layers, rope) diff 
--git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index fdd4dabee17..3ed9f23443b 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -71,6 +71,8 @@ class ModelArgs:
         None  # KV cache bit width. This is for QNN backend only for now.
     )
     attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    # Hybrid models can have layer types other than attention
+    layer_types: Optional[list] = None
 
     def __post_init__(self):
         if self.n_kv_heads is None:
diff --git a/examples/models/llama/source_transformation/spin_quant.py b/examples/models/llama/source_transformation/spin_quant.py
index e07b78dc6ee..c8dcee94617 100644
--- a/examples/models/llama/source_transformation/spin_quant.py
+++ b/examples/models/llama/source_transformation/spin_quant.py
@@ -14,7 +14,7 @@
 
 import torch.nn.functional as F
 
-from executorch.examples.models.llama.llama_transformer import FeedForward
+from executorch.examples.models.llama.feed_forward import FeedForward
 
 from torch import nn
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index 8f8646e88cc..5755b3410cd 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -43,6 +43,9 @@ class ModelType(str, Enum):
     qwen3_4b = "qwen3_4b"
     phi_4_mini = "phi_4_mini"
     smollm2 = "smollm2"
+    lfm2_350m = "lfm2_350m"
+    lfm2_700m = "lfm2_700m"
+    lfm2_1_2b = "lfm2_1_2b"
 
 
 class PreqMode(str, Enum):

From ad9ea346b597c63b8e2d3e8d13cce4b1bcab005b Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 3 Sep 2025 12:42:10 -0700
Subject: [PATCH 2/2] Fix internal tests for LiquidAI LFM2 (#13916)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/13916

Fix internal tests for the external diff generated by https://github.com/pytorch/executorch/pull/13805, which adds LiquidAI's LFM2 model to ExecuTorch.

Reviewed By: mergennachin

Differential Revision: D81491136
---
 examples/models/lfm2/__init__.py           |  5 +++++
 examples/models/lfm2/short_conv.py         |  8 +++++---
 examples/models/llama/TARGETS              | 21 ++++++++++++++++++---
 examples/models/llama/feed_forward.py      |  1 -
 examples/models/llama/llama_transformer.py | 15 +++++++++++++--
 5 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/examples/models/lfm2/__init__.py b/examples/models/lfm2/__init__.py
index 224282df905..1efdc55af81 100644
--- a/examples/models/lfm2/__init__.py
+++ b/examples/models/lfm2/__init__.py
@@ -1,5 +1,10 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ from executorch.examples.models.lfm2.convert_weights import convert_weights +from executorch.examples.models.lfm2.short_conv import ShortConvBlock __all__ = [ "convert_weights", + "ShortConvBlock", ] diff --git a/examples/models/lfm2/short_conv.py b/examples/models/lfm2/short_conv.py index 5a141d4ce61..ae04580d6c6 100644 --- a/examples/models/lfm2/short_conv.py +++ b/examples/models/lfm2/short_conv.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from executorch.examples.models.llama.attention import ForwardOptions from executorch.examples.models.llama.feed_forward import FeedForward @@ -12,8 +14,8 @@ def __init__( dim: int, L_cache: int = 3, bias: bool = False, - device: torch.device = None, - dtype: torch.dtype = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): super().__init__() self.dim = dim @@ -99,7 +101,7 @@ def forward( x, freqs_cos=None, freqs_sin=None, - _unused_attn_options: ForwardOptions = None, + _unused_attn_options: Optional[ForwardOptions] = None, ): # x: 1xN h = self.conv.forward(self.attention_norm(x)) h = x + h diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index c4870ece193..fe26b2f08e0 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -13,6 +13,24 @@ runtime.python_library( name = "llama_transformer", srcs = [ "llama_transformer.py", + ], + _is_external_target = True, + base_module = "executorch.examples.models.llama", + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + ":transformer_modules", + "//caffe2:torch", + "//executorch/examples/models/lfm2:lfm2", + ], +) + +runtime.python_library( + name = "transformer_modules", + srcs = [ + "feed_forward.py", "lora.py", "rope.py", "attention.py", @@ -25,9 +43,6 @@ runtime.python_library( "//executorch/...", "@EXECUTORCH_CLIENTS", ], - deps = [ - "//caffe2:torch", - ], ) runtime.python_library( diff --git a/examples/models/llama/feed_forward.py b/examples/models/llama/feed_forward.py index 3e7af7e0dc8..21a4e27df04 100644 --- a/examples/models/llama/feed_forward.py +++ b/examples/models/llama/feed_forward.py @@ -5,7 +5,6 @@ class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): super().__init__() - assert hidden_dim is not None self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index cdbb0c7557c..3a325d0f4f8 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -12,7 +12,6 @@ import torch import torch.nn.functional as F -from executorch.examples.models.lfm2.short_conv import ShortConvBlock from executorch.examples.models.llama.attention import ( Attention, ATTENTION_REGISTRY, @@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention): self.dim = args.dim self.head_dim = args.head_dim self.attention = attention + + assert ( + args.hidden_dim is not None + ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock." 
if args.moe: self.block_sparse_moe = MOEFeedForward(args) else: self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim) + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) @@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: for layer_id in range(model_args.n_layers): # hybrid models define layer_types if model_args.layer_types and model_args.layer_types[layer_id] == "conv": + from executorch.examples.models.lfm2.short_conv import ShortConvBlock + + assert ( + model_args.hidden_dim is not None + ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock." layers.append( ShortConvBlock( dim=model_args.dim, @@ -253,7 +262,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: ) ) else: - attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs) + attention = cls( + model_args, layer_id, rope, **model_args.attention_kwargs + ) # pyre-ignore[45] transformer_block = TransformerBlock(model_args, attention) layers.append(transformer_block)
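
[Editor's note, not part of the patch] The correctness argument in the `ShortConv.forward` comments (prepending the cached `L_cache - 1` samples instead of using `padding=L_cache - 1`) can be checked in isolation. Below is a standalone sketch under simplifying assumptions: batch size 1, a bare depthwise `nn.Conv1d`, and a zero-initialized state; it is illustrative only and does not exercise the projections or the registered buffer from `short_conv.py`.

```python
# Standalone sketch of the ShortConv caching idea: prepending a zero-initialized
# state of length L_cache - 1 before a depthwise causal conv matches left-padding
# during prefill, and reusing the last L_cache - 1 inputs as the state keeps
# incremental decode consistent with the all-at-once result.
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
dim, L_cache, seq_len = 4, 3, 6
conv = nn.Conv1d(dim, dim, kernel_size=L_cache, groups=dim, bias=False)

x = torch.randn(1, dim, seq_len)

# Reference: causal conv via left padding, processing all tokens at once.
ref = conv(F.pad(x, (L_cache - 1, 0)))

# Cached variant: prefill the first 4 tokens, then decode the last 2 one by one.
state = torch.zeros(1, dim, L_cache - 1)
outs = []
for chunk in (x[..., :4], x[..., 4:5], x[..., 5:6]):
    buf = torch.cat([state, chunk], dim=-1)   # prepend cached context
    outs.append(conv(buf))                    # output length equals chunk length
    state = buf[..., -(L_cache - 1):]         # carry the last L_cache - 1 inputs
cached = torch.cat(outs, dim=-1)

print(torch.allclose(ref, cached, atol=1e-6))  # True
```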