From a011764eaa6880458c070e0278a25d2379f62f5c Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 9 Jun 2025 12:55:23 -0700 Subject: [PATCH 1/6] Add new export LLM config Pull Request resolved: https://github.com/pytorch/executorch/pull/11028 @imported-using-ghimport Differential Revision: [D75263991](https://our.internmc.facebook.com/intern/diff/D75263991/) ghstack-source-id: 289208257 --- examples/models/llama/config/TARGETS | 9 + examples/models/llama/config/llm_config.py | 495 ++++++++++++++++++ examples/models/llama/config/targets.bzl | 26 + .../models/llama/config/test_llm_config.py | 104 ++++ examples/models/llama/export_llama_lib.py | 2 +- 5 files changed, 635 insertions(+), 1 deletion(-) create mode 100644 examples/models/llama/config/TARGETS create mode 100644 examples/models/llama/config/llm_config.py create mode 100644 examples/models/llama/config/targets.bzl create mode 100644 examples/models/llama/config/test_llm_config.py diff --git a/examples/models/llama/config/TARGETS b/examples/models/llama/config/TARGETS new file mode 100644 index 00000000000..2ba1b55a3dd --- /dev/null +++ b/examples/models/llama/config/TARGETS @@ -0,0 +1,9 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py new file mode 100644 index 00000000000..b929e756c3e --- /dev/null +++ b/examples/models/llama/config/llm_config.py @@ -0,0 +1,495 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Configurations for exporting Llama. + +Uses dataclasses, which integrate with OmegaConf and Hydra. +""" + +import ast +import re +from dataclasses import dataclass, field +from enum import Enum +from typing import ClassVar, List, Optional + + +################################################################################ +################################## BaseConfig ################################## +################################################################################ + + +class ModelType(str, Enum): + STORIES110M = "stories110m" + LLAMA2 = "llama2" + LLAMA3 = "llama3" + LLAMA3_1 = "llama3_1" + LLAMA3_2 = "llama3_2" + LLAMA3_2_VISION = "llama3_2_vision" + STATIC_LLAMA = "static_llama" + QWEN2_5 = "qwen2_5" + QWEN3_0_6B = "qwen3-0_6b" + QWEN3_1_7B = "qwen3-1_7b" + QWEN3_4B = "qwen3-4b" + PHI_4_MINI = "phi_4_mini" + SMOLLM2 = "smollm2" + + +class PreqMode(str, Enum): + """ + If you are dealing with pre-quantized checkpoints, this used to + be the way to specify them. Now you don't need to specify these + options if you use a TorchAo-prequantized checkpoint, but they + are still around to preserve backward compatibility. + """ + + PREQ_8DA4W = "8da4w" + PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w" + + +@dataclass +class BaseConfig: + """ + Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini, + and are the minimal set of parameters needed to load the pretrained + eager model and its weights. + + Attributes: + model_class: Which model to to export. 
+ params: Model parameters, such as n_layers, hidden_size, etc. + If left empty will use defaults specified in model_args.py. + checkpoint: Path to the checkpoint file. + If left empty, the model will be initialized with random weights. + checkpoint_dir: Path to directory containing sharded checkpoint files. + tokenizer_path: Path to the tokenizer file. + metadata: Json string containing metadata information. + e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' + use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT. + fairseq2: For legacy internal use cases, this is safe to ignore. + preq_mode: Legacy option to specify how prequantized weights are loaded. + Going forward, ExecuTorch supports loading weights prequantized through + TorchAo as-is, without any special handling. + preq_group_size: Legacy option to specify the group size of prequantized weights. + preq_embedding_quantize: Legacy option to specify how prequantized embeddings + are loaded. + """ + + model_class: ModelType = ModelType.LLAMA3 + params: Optional[str] = None + checkpoint: Optional[str] = None + checkpoint_dir: Optional[str] = None + tokenizer_path: Optional[str] = None + metadata: Optional[str] = None + use_lora: int = int + fairseq2: bool = False + preq_mode: Optional[PreqMode] = None + preq_group_size: int = 32 + preq_embedding_quantize: str = "8,0" + + +################################################################################ +################################# ModelConfig ################################## +################################################################################ + + +class DtypeOverride(str, Enum): + """ + DType of the model. Highly recommended to use "fp32", unless you want to + export without a backend, in which case you can also use "bf16". "fp16" + is not recommended. + """ + + FP32 = "fp32" + FP16 = "fp16" + BF16 = "bf16" + + +@dataclass +class ModelConfig: + """ + Configurations not necessarily specific to the model, but are needed to + finish off the rest of the model configuration in eager. You can think + of these like optimizations / actual configurations. The same ModelConfig + can be applied to multiple models. + + Attributes: + dtype_override: dtype to cast the model to. + enable_dynamic_shape: whether to enable dynamic shapes on the sequence + length so that the model can handle arbitrary prefill lengths and + token generation. + use_shared_embeddings: whether the embedding/output weights should be + shared. Only available with torchao kernels, e.g. when + qmode set to use a "torchao:8da(\\d+)w" pattern. + use_sdpa_with_kv_cache: Whether to use flash attention by substituting + for our custom SDPA op. Note that the naming is poor and this + doesn't actually have anything to do with the kv_cache at the moment. + expand_rope_table: Temporary workaround to expand sin/cos table in head + dim to take vectorized path in optimized kernels. + use_attention_sink: Whether to use attention sink to support multi-round + conversation. Structured as: + ',,', + e.g., '4,2044,1024'. + output_prune_map: Path to the output pruning token mapping file (token_map.json). + input_prune_map: Path to the output pruning token mapping file (token_map.json). + use_kv_cache: Whether to use KV cache. + quantize_kv_cache: Whether to perform int8 per token quantization on the KV cache. + local_global_attention: List of integers specifying local and global attention pattern. + e.g., [0, 16, 0, 16] to specify that every other layer is sliding window of 16. 
+ [0, 16, 32] pattern specifies 2nd and 3rd layers have sliding windows of 16 and 32. + [16] pattern specifies all layers have a sliding window of 16. + """ + + dtype_override: DtypeOverride = DtypeOverride.FP32 + enable_dynamic_shape: bool = True + use_shared_embedding: bool = False + use_sdpa_with_kv_cache: bool = False + expand_rope_table: bool = False + use_attention_sink: Optional[str] = None + output_prune_map: Optional[str] = None + input_prune_map: Optional[str] = None + use_kv_cache: bool = False + quantize_kv_cache: bool = False + local_global_attention: Optional[List[int]] = None + + def __post_init__(self): + self._validate_attention_sink() + self._validate_local_global_attention() + + if self.quantize_kv_cache and not self.use_kv_cache: + raise ValueError( + "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)" + ) + + if self.local_global_attention and not self.use_kv_cache: + raise ValueError( + "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)" + ) + + def _validate_attention_sink(self): + if self.use_attention_sink: + attention_sink_params = self.use_attention_sink.split(",") + if len(attention_sink_params) != 3: + raise ValueError( + "The value of use_attention_sink must be structured like ',,'" + ) + + def _validate_local_global_attention(self): + if self.local_global_attention: + local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]" + try: + parsed = ast.literal_eval(self.local_global_attention) + if not ( + isinstance(parsed, list) and all(isinstance(i, int) for i in parsed) + ): + raise ValueError(local_global_err) + except Exception: + raise ValueError(local_global_err) + + +################################################################################ +################################ ExportConfig ################################## +################################################################################ + + +@dataclass +class ExportConfig: + """ + Configures properties relevant to the export process. + + Attributes: + max_seq_length: Maximum length of sequence to evaluate. + max_context_length: Maximum of context for the model to remember. + output_dir: Output dir to save the exported .pte file to. + output_name: File name to override the exported .pte file. + so_library: Shared library to specify custom quantized operators. + export_only: Whether to stop right after torch.export() and + just save the exported .pt2 graph file. + """ + + max_seq_length: int = 128 + max_context_length: int = 128 + output_dir: Optional[str] = None + output_name: Optional[str] = None + so_library: Optional[str] = None + export_only: bool = False + + def __post_init__(self): + if self.max_context_length > self.max_seq_length: + raise ValueError( + f"max_context_length of {self.max_context_length} cannot be greater than max_seq_length of {self.max_seq_length}" + ) + + +################################################################################ +################################# DebugConfig ################################## +################################################################################ + + +@dataclass +class DebugConfig: + """ + Configures options to debug the export process. + + Attributes: + profile_memory: Whether to generate a chrome trace of activation memory + for intermediate tensors. + profile_path: Use cProfile to profile the export. Results are saved to + profile_path as an html file. 
+ generate_etrecord: Whether to generate an ETRecord debug artifact. + generate_full_logits: Whether to keep the full logits, potentially useful + for debugging purposes. Kept off by default to save memory. + verbose: Whether to log the export process verbosely (log level >= INFO). + """ + + profile_memory: bool = False + profile_path: Optional[str] = None + generate_etrecord: bool = False + generate_full_logits: bool = False + verbose: bool = False + + +################################################################################ +############################# QuantizationConfig ############################### +################################################################################ + + +class Pt2eQuantize(str, Enum): + """ + Type of backend-specific Pt2e quantization strategy to use. + + Pt2e uses a different quantization library that is graph-based + compared to `qmode`, which is also specified in the QuantizationConfig + and is source transform-based. + """ + + XNNPACK_DYNAMIC = "xnnpack_dynamic" + XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4" + QNN_8A8W = "qnn_8a8w" + QNN_16A16W = "qnn_16a16w" + QNN_16A4W = "qnn_16a4w" + COREML_C4W = "coreml_c4w" + COREML_8A_C8W = "coreml_8a_c8w" + COREML_8A_C4W = "coreml_8a_c4w" + COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w" + COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w" + VULKAN_8W = "vulkan_8w" + + +class SpinQuant(str, Enum): + CUDA = "cuda" + NATIVE = "native" + + +@dataclass +class QuantizationConfig: + """ + Configures how the model should be quantized (PTQ). + + Attributes: + qmode: Quantization mode using TorchAo, expressed as a string. + See the __post_init__ validation for available qmode options. + embedding_quantize: Type of embedding quantization. + Must be of the format ',', e.g., '8,1024'. + pt2e_quantize: Quantization mode using pt2e, which is an alternative + to TorchAo that uses backend-aware graph mode quantization rather + than source transformation quantization. + group_size: Group size for quantization. + use_spin_quant: Which spin quant mode to use. If unspecified, don't use + spin quant. + use_qat: Whether the checkpoint is quantization-awarely trained. + calibration_tasks: Tasks for GPTQ calibration from lm_eval. + calibration_limit: Number of samples used for calibration from lm_eval. + calibration_seq_length: Sequence length for GPTQ calibration from lm_eval. + calibration_data: Prompts use for calibration. + """ + + # Constants. + QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] + AO_QUANT_PATTERNS: ClassVar[List[str]] = [ + r"torchao:8da(\d+)w", + r"torchao:fpa(\d+)w", + ] + + qmode: Optional[str] = None + embedding_quantize: Optional[str] = None + pt2e_quantize: Optional[Pt2eQuantize] = None + group_size: Optional[int] = None + use_spin_quant: Optional[SpinQuant] = None + use_qat: bool = False + calibration_tasks: Optional[List[str]] = None + calibration_limit: Optional[int] = None + calibration_seq_length: Optional[int] = None + calibration_data: str = "Once upon a time" + + def __post_init__(self): + if self.qmode: + self._validate_qmode() + + def _validate_qmode(self) -> None: + if not self.qmode: + return + + if self.qmode in self.QMODE_OPTIONS: + return + + # If qmode is one of these below patterns, this means that we + # are using ARM-based torchao ops. 
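+        # For example, an illustrative qmode of "torchao:8da4w" (8-bit dynamic
+        # activations, 4-bit weights) matches r"torchao:8da(\d+)w" below.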
+ for pattern in self.AO_QUANT_PATTERNS: + matches = re.findall(pattern, self.qmode) + if len(matches) == 1: + return + + raise ValueError( + f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}." + ) + + def _validate_embedding_quantize(self): + if len(self.embedding_quantize.split(",")) != 2: + raise ValueError( + f'embedding_quantize of {self.embedding_quantize} must follow the following format: ","' + ) + + +################################################################################ +############################### BackendConfig ################################## +################################################################################ + + +@dataclass +class XNNPackConfig: + """ + Configures the XNNPack backend. + + Attributes: + enabled: :) + extended_ops: Whether to match more types of ops to delegates to XNNPack. + """ + + enabled: bool = False + extended_ops: bool = False + + +class CoreMLQuantize(str, Enum): + B4W = "b4w" + C4W = "c4w" + + +class CoreMLComputeUnit(str, Enum): + CPU_ONLY = "cpu_only" + CPU_AND_GPU = "cpu_and_gpu" + CPU_AND_NE = "cpu_and_ne" + ALL = "all" + + +@dataclass +class CoreMLConfig: + """ + Configures the CoreML backend. + """ + + enabled: bool = False + enable_state: bool = False + preserve_sdpa: bool = False + quantize: Optional[CoreMLQuantize] = None + ios: int = 15 + compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY + + def __post_init__(self): + if self.ios not in (15, 16, 17, 18): + raise ValueError(f"Invalid coreml ios version: {self.ios}") + + +@dataclass +class VulkanConfig: + """ + Configures the Vulkan backend. + """ + + enabled: bool = False + + +@dataclass +class QNNConfig: + """ + Configures the QNN backend. + """ + + enabled: bool = False + use_sha: bool = False + soc_model: str = "SM8650" + use_qnn_sha: bool = False + optimized_rotation_path: Optional[str] = None + num_sharding: int = 0 + + +@dataclass +class MPSConfig: + """ + Configures the MPS backend. + """ + + enabled: bool = False + + +@dataclass +class BackendConfig: + """ + Configures which backends should be used and how the backends + should be set up. + """ + + xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig) + coreml: CoreMLConfig = field(default_factory=CoreMLConfig) + vulkan: VulkanConfig = field(default_factory=VulkanConfig) + qnn: QNNConfig = field(default_factory=QNNConfig) + mps: MPSConfig = field(default_factory=MPSConfig) + + +################################################################################ +################################## LlmConfig ################################### +################################################################################ + + +@dataclass +class LlmConfig: + """ + The overall configuration for customizing the LLM export process. 
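+
+    Example (an illustrative sketch; the checkpoint and tokenizer paths below
+    are placeholders, not files shipped with this patch):
+
+        LlmConfig(
+            base=BaseConfig(
+                model_class=ModelType.LLAMA3,
+                checkpoint="/path/to/consolidated.00.pth",
+                tokenizer_path="/path/to/tokenizer.model",
+            ),
+            model=ModelConfig(use_kv_cache=True, use_sdpa_with_kv_cache=True),
+            export=ExportConfig(output_name="llama3.pte"),
+            backend=BackendConfig(xnnpack=XNNPackConfig(enabled=True)),
+        )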
+ """ + + base: BaseConfig = field(default_factory=BaseConfig) + model: ModelConfig = field(default_factory=ModelConfig) + export: ExportConfig = field(default_factory=ExportConfig) + debug: DebugConfig = field(default_factory=DebugConfig) + quantization: QuantizationConfig = field(default_factory=QuantizationConfig) + backend: BackendConfig = field(default_factory=BackendConfig) + + def __post_init__(self): + self._validate_low_bit() + + def _validate_low_bit(self): + if not self.quantization.qmode: + return + + using_lowbit_ops = False + for pattern in self.quantization.AO_QUANT_PATTERNS: + matches = re.findall(pattern, self.quantization.qmode) + if len(matches) == 1: + using_lowbit_ops = True + + # If we are using Ao's low bit quantization kernels for ARM, + # we do not want to also be delegating to a CPU backend (XNNPack). + if using_lowbit_ops and self.backend.xnnpack.enabled: + raise ValueError( + "Cannot use low-bit Ao ops (from qmode=torchao:...) while also delegating to XNNPack." + ) + + # Also we can only use shared embeddings if we are using low bit kernels. + if self.model.use_shared_embedding and not using_lowbit_ops: + raise ValueError( + "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)." + ) diff --git a/examples/models/llama/config/targets.bzl b/examples/models/llama/config/targets.bzl new file mode 100644 index 00000000000..8b85ce6d107 --- /dev/null +++ b/examples/models/llama/config/targets.bzl @@ -0,0 +1,26 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +def define_common_targets(): + runtime.python_library( + name = "llm_config", + srcs = [ + "llm_config.py", + ], + _is_external_target = True, + base_module = "executorch.examples.models.llama.config", + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + ) + + python_unittest( + name = "test_llm_config", + srcs = [ + "test_llm_config.py", + ], + deps = [ + ":llm_config", + ], + ) diff --git a/examples/models/llama/config/test_llm_config.py b/examples/models/llama/config/test_llm_config.py new file mode 100644 index 00000000000..0853e9dbbd8 --- /dev/null +++ b/examples/models/llama/config/test_llm_config.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import unittest + +from executorch.examples.models.llama.config.llm_config import ( + BackendConfig, + BaseConfig, + CoreMLComputeUnit, + CoreMLConfig, + DebugConfig, + ExportConfig, + LlmConfig, + ModelConfig, + QuantizationConfig, + XNNPackConfig, +) + + +class TestValidation(unittest.TestCase): + def test_invalid_attention_sink(self): + with self.assertRaises(ValueError): + ModelConfig(use_attention_sink="4,2048") + + def test_invalid_local_global_attention_format(self): + with self.assertRaises(ValueError): + ModelConfig(local_global_attention="notalist") + + def test_quantize_kv_without_kv(self): + with self.assertRaises(ValueError): + ModelConfig(quantize_kv_cache=True) + + def test_local_global_attention_without_kv(self): + with self.assertRaises(ValueError): + ModelConfig(local_global_attention="[16]", use_kv_cache=False) + + def test_invalid_export_config_context_length(self): + with self.assertRaises(ValueError): + ExportConfig(max_seq_length=128, max_context_length=256) + + def test_invalid_qmode(self): + with self.assertRaises(ValueError): + QuantizationConfig(qmode="unknown") + + def test_invalid_coreml_ios(self): + with self.assertRaises(ValueError): + CoreMLConfig(ios=14) + + def test_lowbit_conflict_with_xnnpack(self): + qcfg = QuantizationConfig(qmode="torchao:8da4w") + bcfg = BackendConfig(xnnpack=XNNPackConfig(enabled=True)) + model_cfg = ModelConfig(use_shared_embedding=True) + + with self.assertRaises(ValueError): + LlmConfig(model=model_cfg, quantization=qcfg, backend=bcfg) + + def test_shared_embedding_without_lowbit(self): + model_cfg = ModelConfig(use_shared_embedding=True) + qcfg = QuantizationConfig(qmode="int8") + + with self.assertRaises(ValueError): + LlmConfig(model=model_cfg, quantization=qcfg) + + +class TestValidConstruction(unittest.TestCase): + + def test_valid_llm_config(self): + LlmConfig( + base=BaseConfig( + model_class="llama3", + checkpoint="checkpoints/model.pt", + tokenizer_path="tokenizer.json", + use_lora=8, + ), + model=ModelConfig( + dtype_override="fp32", + use_attention_sink="4,2048,1024", + use_kv_cache=True, + local_global_attention="[16, 32]", + ), + export=ExportConfig( + max_seq_length=256, + max_context_length=128, + output_dir="/tmp/export", + output_name="model.pte", + ), + debug=DebugConfig(profile_memory=True, verbose=True), + quantization=QuantizationConfig(qmode="torchao:8da4w"), + backend=BackendConfig( + xnnpack=XNNPackConfig(enabled=False), + coreml=CoreMLConfig( + enabled=True, ios=17, compute_units=CoreMLComputeUnit.ALL + ), + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 96faf64475e..11fb2fa3cbb 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -486,7 +486,7 @@ def build_args_parser() -> argparse.ArgumentParser: "--use_qat", default=False, action="store_true", - help="Whether the checkpoin is pre-quantized with QAT or not.", + help="Whether the checkpoint is pre-quantized with QAT or not.", ) parser.add_argument( From 8f1c751d3e9d5f9b1c6abb38a7b0517e86a46717 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:56:25 -0700 Subject: [PATCH 2/6] Introduce hydra framework with backwards compatibility Pull Request resolved: https://github.com/pytorch/executorch/pull/11029 @imported-using-ghimport Differential Revision: 
[D75263989](https://our.internmc.facebook.com/intern/diff/D75263989/) ghstack-source-id: 289227700 --- examples/models/llama/TARGETS | 4 ++ examples/models/llama/config/llm_config.py | 13 +++++++ examples/models/llama/export_llama.py | 38 ++++++++++++++----- examples/models/llama/export_llama_args.py | 21 ++++++++++ examples/models/llama/export_llama_hydra.py | 27 +++++++++++++ examples/models/llama/export_llama_lib.py | 21 +++++++++- examples/models/llama/install_requirements.sh | 2 +- 7 files changed, 115 insertions(+), 11 deletions(-) create mode 100644 examples/models/llama/export_llama_args.py create mode 100644 examples/models/llama/export_llama_hydra.py diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 872eccce872..b51e164d483 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -132,6 +132,8 @@ runtime.python_library( name = "export_library", srcs = [ "export_llama.py", + "export_llama_args.py", + "export_llama_hydra.py", "export_llama_lib.py", "model.py", ], @@ -148,6 +150,8 @@ runtime.python_library( ":source_transformation", "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform", "//caffe2:torch", + "//executorch/examples/models/llama/config:llm_config", + "//executorch/backends/vulkan/_passes:vulkan_passes", "//executorch/exir/passes:init_mutable_pass", "//executorch/examples/models:model_base", "//executorch/examples/models:models", diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py index b929e756c3e..d9a8a8e6192 100644 --- a/examples/models/llama/config/llm_config.py +++ b/examples/models/llama/config/llm_config.py @@ -12,6 +12,7 @@ Uses dataclasses, which integrate with OmegaConf and Hydra. """ +import argparse import ast import re from dataclasses import dataclass, field @@ -468,6 +469,18 @@ class LlmConfig: quantization: QuantizationConfig = field(default_factory=QuantizationConfig) backend: BackendConfig = field(default_factory=BackendConfig) + @classmethod + def from_args(cls, args: argparse.Namespace) -> "LlmConfig": + """ + To support legacy purposes, this function converts CLI args from + argparse to an LlmConfig, which is used by the LLM export process. + """ + llm_config = LlmConfig() + + # TODO: conversion code. + + return llm_config + def __post_init__(self): self._validate_low_bit() diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index e25a8a007eb..63e76e28ba9 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -4,30 +4,50 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Example script for exporting Llama2 to flatbuffer - -import logging - # force=True to ensure logging while in debugger. Set up logger before any # other imports. 
+import logging + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT, force=True) +import argparse +import runpy import sys import torch -from .export_llama_lib import build_args_parser, export_llama - sys.setrecursionlimit(4096) +def parse_hydra_arg(): + """First parse out the arg for whether to use Hydra or the old CLI.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument("--hydra", action="store_true") + args, remaining = parser.parse_known_args() + return args.hydra, remaining + + def main() -> None: seed = 42 torch.manual_seed(seed) - parser = build_args_parser() - args = parser.parse_args() - export_llama(args) + + use_hydra, remaining_args = parse_hydra_arg() + if use_hydra: + # The import runs the main function of export_llama_hydra with the remaining args + # under the Hydra framework. + sys.argv = [arg for arg in sys.argv if arg != "--hydra"] + print(f"running with {sys.argv}") + runpy.run_module( + "executorch.examples.models.llama.export_llama_hydra", run_name="__main__" + ) + else: + # Use the legacy version of the export_llama script which uses argsparse. + from executorch.examples.models.llama.export_llama_args import ( + main as export_llama_args_main, + ) + + export_llama_args_main(remaining_args) if __name__ == "__main__": diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py new file mode 100644 index 00000000000..7a176d9b7d0 --- /dev/null +++ b/examples/models/llama/export_llama_args.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run export_llama with the legacy argparse setup. +""" + +from .export_llama_lib import build_args_parser, export_llama + + +def main(args) -> None: + parser = build_args_parser() + args = parser.parse_args(args) + export_llama(args) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py new file mode 100644 index 00000000000..73eca7e2a5a --- /dev/null +++ b/examples/models/llama/export_llama_hydra.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run export_llama using the new Hydra CLI. 
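+
+Overrides use Hydra's dotted-path syntax and map directly onto the LlmConfig
+dataclass tree. An illustrative invocation (the paths are placeholders and the
+module path may differ depending on how ExecuTorch is installed):
+
+    python -m examples.models.llama.export_llama --hydra \
+        base.checkpoint=/path/to/consolidated.00.pth \
+        base.params=/path/to/params.json \
+        model.use_kv_cache=True \
+        backend.xnnpack.enabled=True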
+""" + +import hydra + +from executorch.examples.models.llama.config.llm_config import LlmConfig +from executorch.examples.models.llama.export_llama_lib import export_llama +from hydra.core.config_store import ConfigStore + +cs = ConfigStore.instance() +cs.store(name="llm_config", node=LlmConfig) + + +@hydra.main(version_base=None, config_name="llm_config") +def main(llm_config: LlmConfig) -> None: + export_llama(llm_config) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 11fb2fa3cbb..12406cc762e 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -27,6 +27,8 @@ from executorch.devtools.backend_debug import print_delegation_info from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func + +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.hf_download import ( download_and_convert_hf_checkpoint, ) @@ -50,6 +52,7 @@ get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace +from omegaconf.dictconfig import DictConfig from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( @@ -567,7 +570,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: return return_val -def export_llama(args) -> str: +def export_llama( + export_options: Union[argparse.Namespace, DictConfig], +) -> str: + if isinstance(export_options, argparse.Namespace): + # Legacy CLI. + args = export_options + llm_config = LlmConfig.from_args(export_options) # noqa: F841 + elif isinstance(export_options, DictConfig): + # Hydra CLI. + llm_config = export_options # noqa: F841 + else: + raise ValueError( + "Input to export_llama must be either of type argparse.Namespace or LlmConfig" + ) + + # TODO: refactor rest of export_llama to use llm_config instead of args. + # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh index b9e0f9210c5..580a152a322 100755 --- a/examples/models/llama/install_requirements.sh +++ b/examples/models/llama/install_requirements.sh @@ -10,7 +10,7 @@ # Install tokenizers for hf .json tokenizer. # Install snakeviz for cProfile flamegraph # Install lm-eval for Model Evaluation with lm-evalution-harness. 
-pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile +pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile # Call the install helper for further setup python examples/models/llama/install_requirement_helper.py From c638c09058233732379e2f667f0c254c3057df5a Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 9 Jun 2025 13:56:27 -0700 Subject: [PATCH 3/6] Convert args to LlmConfig Pull Request resolved: https://github.com/pytorch/executorch/pull/11081 @imported-using-ghimport Differential Revision: [D75263990](https://our.internmc.facebook.com/intern/diff/D75263990/) ghstack-source-id: 289227697 --- examples/apple/mps/scripts/mps_example.py | 2 +- examples/models/llama/config/llm_config.py | 146 ++++++++++++++++++++- 2 files changed, 145 insertions(+), 3 deletions(-) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index c1a2e150286..2fc67bcca0e 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -145,7 +145,7 @@ def get_model_config(args): return model_config -if __name__ == "__main__": +if __name__ == "__main__": # noqa: C901 args = parse_args() if args.model_name not in MODEL_NAME_TO_MODEL: diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py index d9a8a8e6192..a5c486a8c1e 100644 --- a/examples/models/llama/config/llm_config.py +++ b/examples/models/llama/config/llm_config.py @@ -470,14 +470,156 @@ class LlmConfig: backend: BackendConfig = field(default_factory=BackendConfig) @classmethod - def from_args(cls, args: argparse.Namespace) -> "LlmConfig": + def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 """ To support legacy purposes, this function converts CLI args from argparse to an LlmConfig, which is used by the LLM export process. """ llm_config = LlmConfig() - # TODO: conversion code. 
+ # BaseConfig + if hasattr(args, "model"): + llm_config.base.model_class = ModelType(args.model) + if hasattr(args, "params"): + llm_config.base.params = args.params + if hasattr(args, "checkpoint"): + llm_config.base.checkpoint = args.checkpoint + if hasattr(args, "checkpoint_dir"): + llm_config.base.checkpoint_dir = args.checkpoint_dir + if hasattr(args, "tokenizer_path"): + llm_config.base.tokenizer_path = args.tokenizer_path + if hasattr(args, "metadata"): + llm_config.base.metadata = args.metadata + if hasattr(args, "use_lora"): + llm_config.base.use_lora = args.use_lora + if hasattr(args, "fairseq2"): + llm_config.base.fairseq2 = args.fairseq2 + + # PreqMode settings + if hasattr(args, "preq_mode") and args.preq_mode: + llm_config.base.preq_mode = PreqMode(args.preq_mode) + if hasattr(args, "preq_group_size"): + llm_config.base.preq_group_size = args.preq_group_size + if hasattr(args, "preq_embedding_quantize"): + llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize + + # ModelConfig + if hasattr(args, "dtype_override"): + llm_config.model.dtype_override = DtypeOverride(args.dtype_override) + if hasattr(args, "enable_dynamic_shape"): + llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape + if hasattr(args, "use_shared_embedding"): + llm_config.model.use_shared_embedding = args.use_shared_embedding + if hasattr(args, "use_sdpa_with_kv_cache"): + llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache + if hasattr(args, "expand_rope_table"): + llm_config.model.expand_rope_table = args.expand_rope_table + if hasattr(args, "use_attention_sink"): + llm_config.model.use_attention_sink = args.use_attention_sink + if hasattr(args, "output_prune_map"): + llm_config.model.output_prune_map = args.output_prune_map + if hasattr(args, "input_prune_map"): + llm_config.model.input_prune_map = args.input_prune_map + if hasattr(args, "use_kv_cache"): + llm_config.model.use_kv_cache = args.use_kv_cache + if hasattr(args, "quantize_kv_cache"): + llm_config.model.quantize_kv_cache = args.quantize_kv_cache + if hasattr(args, "local_global_attention"): + llm_config.model.local_global_attention = args.local_global_attention + + # ExportConfig + if hasattr(args, "max_seq_length"): + llm_config.export.max_seq_length = args.max_seq_length + if hasattr(args, "max_context_length"): + llm_config.export.max_context_length = args.max_context_length + if hasattr(args, "output_dir"): + llm_config.export.output_dir = args.output_dir + if hasattr(args, "output_name"): + llm_config.export.output_name = args.output_name + if hasattr(args, "so_library"): + llm_config.export.so_library = args.so_library + if hasattr(args, "export_only"): + llm_config.export.export_only = args.export_only + + # QuantizationConfig + if hasattr(args, "quantization_mode"): + llm_config.quantization.qmode = args.quantization_mode + if hasattr(args, "embedding_quantize"): + llm_config.quantization.embedding_quantize = args.embedding_quantize + if hasattr(args, "pt2e_quantize") and args.pt2e_quantize: + llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize) + if hasattr(args, "group_size"): + llm_config.quantization.group_size = args.group_size + if hasattr(args, "use_spin_quant") and args.use_spin_quant: + llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant) + if hasattr(args, "use_qat"): + llm_config.quantization.use_qat = args.use_qat + if hasattr(args, "calibration_tasks"): + llm_config.quantization.calibration_tasks = args.calibration_tasks + if hasattr(args, 
"calibration_limit"): + llm_config.quantization.calibration_limit = args.calibration_limit + if hasattr(args, "calibration_seq_length"): + llm_config.quantization.calibration_seq_length = args.calibration_seq_length + if hasattr(args, "calibration_data"): + llm_config.quantization.calibration_data = args.calibration_data + + # BackendConfig - XNNPack + if hasattr(args, "xnnpack"): + llm_config.backend.xnnpack.enabled = args.xnnpack + if hasattr(args, "xnnpack_extended_ops"): + llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops + + # CoreML + if hasattr(args, "coreml"): + llm_config.backend.coreml.enabled = args.coreml + llm_config.backend.coreml.enable_state = getattr( + args, "coreml_enable_state", False + ) + llm_config.backend.coreml.preserve_sdpa = getattr( + args, "coreml_preserve_sdpa", False + ) + if hasattr(args, "coreml_quantize") and args.coreml_quantize: + llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize) + if hasattr(args, "coreml_ios"): + llm_config.backend.coreml.ios = args.coreml_ios + if hasattr(args, "coreml_compute_units"): + llm_config.backend.coreml.compute_units = CoreMLComputeUnit( + args.coreml_compute_units + ) + + # Vulkan + if hasattr(args, "vulkan"): + llm_config.backend.vulkan.enabled = args.vulkan + + # QNN + if hasattr(args, "qnn"): + llm_config.backend.qnn.enabled = args.qnn + if hasattr(args, "use_qnn_sha"): + llm_config.backend.qnn.use_sha = args.use_qnn_sha + if hasattr(args, "soc_model"): + llm_config.backend.qnn.soc_model = args.soc_model + if hasattr(args, "optimized_rotation_path"): + llm_config.backend.qnn.optimized_rotation_path = ( + args.optimized_rotation_path + ) + if hasattr(args, "num_sharding"): + llm_config.backend.qnn.num_sharding = args.num_sharding + + # MPS + if hasattr(args, "mps"): + llm_config.backend.mps.enabled = args.mps + + # DebugConfig + if hasattr(args, "profile_memory"): + llm_config.debug.profile_memory = args.profile_memory + if hasattr(args, "profile_path"): + llm_config.debug.profile_path = args.profile_path + if hasattr(args, "generate_etrecord"): + llm_config.debug.generate_etrecord = args.generate_etrecord + if hasattr(args, "generate_full_logits"): + llm_config.debug.generate_full_logits = args.generate_full_logits + if hasattr(args, "verbose"): + llm_config.debug.verbose = args.verbose return llm_config From bc91c87e1e7fefae15eb80d5ced374f9a7b06d71 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:22:30 -0700 Subject: [PATCH 4/6] Use llm_config instead of args in export_llama functions Pull Request resolved: https://github.com/pytorch/executorch/pull/11162 @imported-using-ghimport Differential Revision: [D75484927](https://our.internmc.facebook.com/intern/diff/D75484927/) ghstack-source-id: 289273387 --- backends/arm/test/models/test_llama.py | 4 +- examples/apple/mps/scripts/mps_example.py | 29 +- examples/models/llama/TARGETS | 1 + examples/models/llama/eval_llama_lib.py | 60 +-- examples/models/llama/export_llama_hydra.py | 3 +- examples/models/llama/export_llama_lib.py | 344 ++++++++---------- examples/models/llama/model.py | 93 ++--- examples/models/llama/runner/eager.py | 56 ++- .../llama/tests/test_export_llama_lib.py | 4 +- .../models/llama3_2_vision/runner/eager.py | 18 +- examples/models/llava/export_llava.py | 73 ++-- 11 files changed, 354 insertions(+), 331 deletions(-) diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py index d0a18d88b9d..c11ff478e6f 100644 --- 
a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -22,6 +22,7 @@ TosaPipelineMI, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( build_args_parser, get_llama_model, @@ -89,8 +90,9 @@ def prepare_model(self): ] parser = build_args_parser() args = parser.parse_args(args) + llm_config = LlmConfig.from_args(args) - llama_model, llama_inputs, llama_meta = get_llama_model(args) + llama_model, llama_inputs, llama_meta = get_llama_model(llm_config) return llama_model, llama_inputs, llama_meta diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 2fc67bcca0e..5ccbc987b4d 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -20,6 +20,7 @@ serialize_from_bundled_program_to_flatbuffer, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -131,28 +132,24 @@ def parse_args(): return args -def get_model_config(args): - model_config = {} - model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0] - model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1] - - if args.model_name == "llama2": - if args.checkpoint: - model_config["checkpoint"] = args.checkpoint - if args.params: - model_config["params"] = args.params - model_config["use_kv_cache"] = True - return model_config - - if __name__ == "__main__": # noqa: C901 args = parse_args() if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") - model_config = get_model_config(args) - model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config) + llm_config = LlmConfig() + if args.model_name == "llama2": + if args.checkpoint: + llm_config.base.checkpoint = args.checkpoint + if args.params: + llm_config.base.params = args.params + llm_config.model.use_kv_cache = True + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + llm_config=llm_config, + ) model = model.eval() diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index b51e164d483..86b7e957628 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -67,6 +67,7 @@ runtime.python_library( "//caffe2:torch", "//executorch/examples/models:model_base", "//executorch/examples/models/llama:llama_transformer", + "//executorch/examples/models/llama/config:llm_config", "//executorch/examples/models:checkpoint", ], ) diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 47b13df52e0..20ba6dbaa9f 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -164,6 +164,7 @@ def _model_call(self, inps): def gen_eval_wrapper( model_name: str, args: argparse.ArgumentParser, + llm_config=None, ): """ Generates a wrapper interface around the provided model and tokenizer for @@ -172,7 +173,13 @@ def gen_eval_wrapper( Returns: eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. 
""" - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore + # If llm_config is not provided, convert args to llm_config + if llm_config is None: + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) # ExecuTorch Binary Evaluation if (model := args.pte) is not None: # pyre-ignore @@ -182,7 +189,7 @@ def gen_eval_wrapper( model=model, tokenizer=tokenizer, tokenizer_bin=tokenizer_bin, - max_seq_length=args.max_seq_length, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, ) # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings @@ -191,12 +198,14 @@ def gen_eval_wrapper( tokenizer=tokenizer, # Exported model takes at most (max_seq_length - 1) tokens. # Note that the eager model takes at most max_seq_length tokens. - max_seq_length=args.max_seq_length - 1, + max_seq_length=llm_config.export.max_seq_length - 1, ) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) @@ -208,9 +217,9 @@ def gen_eval_wrapper( return GraphModuleEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, # pyre-ignore - enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch @@ -234,8 +243,8 @@ def gen_eval_wrapper( return EagerEvalWrapper( model=model, tokenizer=tokenizer, - max_seq_length=args.max_seq_length, - use_kv_cache=args.use_kv_cache, + max_seq_length=llm_config.export.max_seq_length, + use_kv_cache=llm_config.model.use_kv_cache, ) @@ -296,12 +305,16 @@ def eval_llama( model_name: str, args: argparse.ArgumentParser, ) -> None: + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + # Generate the eval wrapper - eval_wrapper = gen_eval_wrapper(model_name, args) + eval_wrapper = gen_eval_wrapper(model_name, args, llm_config) # Needed for loading mmlu dataset. 
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files - # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks` if args.tasks and "mmlu" in args.tasks: import datasets @@ -312,8 +325,8 @@ def eval_llama( eval_results = simple_evaluate( model=eval_wrapper, tasks=args.tasks, - num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot` - limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit` + num_fewshot=args.num_fewshot, + limit=args.limit, ) for task, res in eval_results["results"].items(): @@ -326,19 +339,24 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py """ - assert args.use_attention_sink is not None # pyre-ignore [16] - assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16] - attention_sink_params = args.use_attention_sink.split(",") + # Convert args to LlmConfig + from executorch.examples.models.llama.config.llm_config import LlmConfig + + llm_config = LlmConfig.from_args(args) + + assert llm_config.model.use_attention_sink is not None + assert args.attention_sink_eval_tokens > 0 + attention_sink_params = llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) - assert args.max_seq_length == sink_size + window_size # pyre-ignore [16] + assert llm_config.export.max_seq_length == sink_size + window_size device = "cuda" if torch.cuda.is_available() else "cpu" - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) model = manager.model.eval().to(device=device) - tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16] + tokenizer = get_tokenizer(llm_config.base.tokenizer_path) eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") @@ -347,7 +365,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse progress_bar = tqdm(total=args.attention_sink_eval_tokens) input_pos = 0 while input_pos < args.attention_sink_eval_tokens: - for text in eval_data["text"]: # pyre-ignore [16] + for text in eval_data["text"]: tokens = tokenizer.encode(text, bos=False, eos=False) if len(tokens) <= 0: continue diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py index 73eca7e2a5a..4871de00e25 100644 --- a/examples/models/llama/export_llama_hydra.py +++ b/examples/models/llama/export_llama_hydra.py @@ -13,6 +13,7 @@ from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import export_llama from hydra.core.config_store import ConfigStore +from omegaconf import OmegaConf cs = ConfigStore.instance() cs.store(name="llm_config", node=LlmConfig) @@ -20,7 +21,7 @@ @hydra.main(version_base=None, config_name="llm_config") def main(llm_config: LlmConfig) -> None: - export_llama(llm_config) + export_llama(OmegaConf.to_object(llm_config)) if __name__ == "__main__": diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 12406cc762e..1f055d65822 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -52,7 +52,6 @@ get_vulkan_quantizer, ) from 
executorch.util.activation_memory_profiler import generate_memory_trace -from omegaconf.dictconfig import DictConfig from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( @@ -153,7 +152,8 @@ def build_model( argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}" parser = build_args_parser() args = parser.parse_args(shlex.split(argString)) - return export_llama(args) + llm_config = LlmConfig.from_args(args) + return export_llama(llm_config) def parse_list_of_ints(s): @@ -571,54 +571,53 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: def export_llama( - export_options: Union[argparse.Namespace, DictConfig], + export_options: Union[argparse.Namespace, LlmConfig], ) -> str: if isinstance(export_options, argparse.Namespace): # Legacy CLI. - args = export_options - llm_config = LlmConfig.from_args(export_options) # noqa: F841 - elif isinstance(export_options, DictConfig): + llm_config = LlmConfig.from_args(export_options) + elif isinstance(export_options, LlmConfig): # Hydra CLI. - llm_config = export_options # noqa: F841 + llm_config = export_options else: raise ValueError( "Input to export_llama must be either of type argparse.Namespace or LlmConfig" ) - # TODO: refactor rest of export_llama to use llm_config instead of args. - # If a checkpoint isn't provided for an HF OSS model, download and convert the # weights first. - if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: - repo_id = HUGGING_FACE_REPO_IDS[args.model] - if args.model == "qwen2_5": + model_name = llm_config.base.model_class + if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS: + repo_id = HUGGING_FACE_REPO_IDS[model_name] + if model_name == "qwen2_5": from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21] convert_weights, ) - elif args.model.startswith("qwen3"): + elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "phi_4_mini": + elif model_name == "phi_4_mini": from executorch.examples.models.phi_4_mini import ( # pyre-ignore[21] convert_weights, ) - elif args.model == "smollm2": + elif model_name == "smollm2": from executorch.examples.models.smollm2 import ( # pyre-ignore[21] convert_weights, ) else: raise ValueError( - f"Converting weights to meta format for {args.model} is not yet supported" + f"Converting weights to meta format for {model_name} is not yet supported" ) - args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights) + llm_config.base.checkpoint = checkpoint - if args.profile_path is not None: + if llm_config.debug.profile_path is not None: try: from executorch.util.python_profiler import CProfilerFlameGraph - with CProfilerFlameGraph(args.profile_path): - builder = _export_llama(args) + with CProfilerFlameGraph(llm_config.debug.profile_path): + builder = _export_llama(llm_config) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" @@ -629,14 +628,14 @@ def export_llama( ) return "" else: - builder = _export_llama(args) + builder = _export_llama(llm_config) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" return filename -def _prepare_for_llama_export(args) -> LLMEdgeManager: +def _prepare_for_llama_export(llm_config: LlmConfig) -> 
LLMEdgeManager: """ Helper function for export_llama. Loads the model from checkpoint and params, and sets up a LLMEdgeManager with initial transforms and dtype conversion. @@ -644,41 +643,30 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: Returns a LLMEdgeManager prior to calling export_to_edge with quantizers """ # load model from checkpoint and params.json - checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None + checkpoint_path = ( + canonical_path(llm_config.base.checkpoint) + if llm_config.base.checkpoint + else None + ) checkpoint_dir = ( - canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None + canonical_path(llm_config.base.checkpoint_dir) + if llm_config.base.checkpoint_dir + else None ) - params_path = canonical_path(args.params) if args.params else None - output_dir_path = canonical_path(args.output_dir, dir=True) - weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA - - # Convert dtype override string arg to actual type. - dtype_override = DType[args.dtype_override] - - edge_manager = _load_llama_model( - args.model, - checkpoint=checkpoint_path, - checkpoint_dir=checkpoint_dir, - params_path=params_path, - use_kv_cache=args.use_kv_cache, - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - generate_full_logits=args.generate_full_logits, - weight_type=weight_type, - enable_dynamic_shape=args.enable_dynamic_shape, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - calibration_data=args.calibration_data, - tokenizer_path=args.tokenizer_path, - verbose=args.verbose, - max_seq_len=args.max_seq_length, - max_context_len=args.max_context_length, - input_prune_map_path=args.input_prune_map, - output_prune_map_path=args.output_prune_map, - metadata_str=args.metadata, - dtype_override=dtype_override, - args=args, + params_path = ( + canonical_path(llm_config.base.params) if llm_config.base.params else None ) + output_dir_path = canonical_path(llm_config.export.output_dir, dir=True) + + llm_config.base.checkpoint = checkpoint_path + llm_config.base.checkpoint_dir = checkpoint_dir + llm_config.base.params = params_path + llm_config.export.output_dir = output_dir_path + + # Convert dtype override string to actual type. + dtype_override = DType[llm_config.model.dtype_override] + + edge_manager = _load_llama_model(llm_config) # At this point, the model is loaded in the default fp32. 
@@ -707,64 +695,64 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform( _get_source_transforms( dtype_override=dtype_override, - checkpoint=args.checkpoint, + checkpoint=llm_config.base.checkpoint, checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore - tokenizer_path=args.tokenizer_path, - use_spin_quant=args.use_spin_quant, - embedding_quantize=args.embedding_quantize, - use_shared_embedding=args.use_shared_embedding, - quantization_mode=args.quantization_mode, - group_size=args.group_size, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, - expand_rope_table=args.expand_rope_table, + tokenizer_path=llm_config.base.tokenizer_path, + use_spin_quant=llm_config.quantization.use_spin_quant, + embedding_quantize=llm_config.quantization.embedding_quantize, + use_shared_embedding=llm_config.model.use_shared_embedding, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + expand_rope_table=llm_config.model.expand_rope_table, use_custom_sdpa_with_attention_mask=getattr( - args, "use_custom_sdpa_with_attention_mask", False + llm_config.model, "use_custom_sdpa_with_attention_mask", False ), - use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, - quantize_kv_cache=args.quantize_kv_cache, - use_kv_cache=args.use_kv_cache, - qnn=args.qnn, - use_qnn_sha=args.use_qnn_sha, - optimized_rotation_path=args.optimized_rotation_path, - mps=args.mps, - coreml=args.coreml, - coreml_ios=args.coreml_ios, - vulkan=args.vulkan, - use_qat=args.use_qat, - use_lora=args.use_lora, - preq_mode=args.preq_mode, - preq_group_size=args.preq_group_size, - preq_embedding_quantize=args.preq_embedding_quantize, - local_global_attention=args.local_global_attention, + use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache, + quantize_kv_cache=llm_config.model.quantize_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, + qnn=llm_config.backend.qnn.enabled, + use_qnn_sha=llm_config.backend.qnn.use_sha, + optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + coreml_ios=llm_config.backend.coreml.ios, + vulkan=llm_config.backend.vulkan.enabled, + use_qat=llm_config.quantization.use_qat, + use_lora=llm_config.base.use_lora, + preq_mode=llm_config.base.preq_mode, + preq_group_size=llm_config.base.preq_group_size, + preq_embedding_quantize=llm_config.base.preq_embedding_quantize, + local_global_attention=llm_config.model.local_global_attention, ) ) return edge_manager -def get_quantizer_and_quant_params(args): +def get_quantizer_and_quant_params(llm_config): pt2e_quant_params = get_pt2e_quantization_params( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) - quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library) + quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library) quant_dtype = None - if args.qnn and args.pt2e_quantize: + if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" qnn_quantizer, 
quant_dtype = get_qnn_quantizer( - args.pt2e_quantize, args.quantization_mode + llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode ) quantizers.append(qnn_quantizer) - if args.coreml and args.pt2e_quantize: + if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize: assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" - coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize) + coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(coreml_quantizer) - if args.vulkan and args.pt2e_quantize: + if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: assert ( len(quantizers) == 0 ), "Should not enable both vulkan and other quantizers" - vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize) + vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize) quantizers.append(vulkan_quantizer) logging.info(f"Applying quantizers: {quantizers}") return pt2e_quant_params, quantizers, quant_dtype @@ -787,28 +775,28 @@ def _qmode_type(value): ) -def _validate_args(args): - """ - TODO: Combine all the backends under --backend args - """ - - if args.max_context_length < args.max_seq_length: +def _validate_args(llm_config): + if llm_config.export.max_context_length < llm_config.export.max_seq_length: raise ValueError( - f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length." + f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length." ) - if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn): + if llm_config.model.enable_dynamic_shape and ( + llm_config.backend.coreml.enabled + or llm_config.backend.mps.enabled + or llm_config.backend.qnn.enabled + ): raise ValueError( "Dynamic shape is not supported with coreml, MPS or qnn backends." " Please use --disable_dynamic_shape." ) - if args.num_sharding > 0 and not args.qnn: + if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled: raise ValueError("Model shard is only supported with qnn backend now.") - if args.use_shared_embedding: + if llm_config.model.use_shared_embedding: if not ( - args.embedding_quantize is not None - and args.embedding_quantize.startswith("torchao:") + llm_config.quantization.embedding_quantize is not None + and llm_config.quantization.embedding_quantize.startswith("torchao:") ): raise ValueError( "Shared embedding is only supported with torchao quantization." 
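For reference, the hunks above switch the export entry points from an argparse Namespace to the structured LlmConfig. A minimal sketch of driving the refactored helpers directly; this is not part of the patch itself, it assumes a default-constructed config from llm_config.py and only flips fields that appear in the hunks above:

    from executorch.examples.models.llama.config.llm_config import LlmConfig
    from executorch.examples.models.llama.export_llama_lib import (
        _validate_args,
        get_quantizer_and_quant_params,
    )

    # Start from defaults and enable only what this run needs.
    llm_config = LlmConfig()
    llm_config.model.use_kv_cache = True
    llm_config.backend.xnnpack.enabled = True

    # Raises ValueError when, e.g., export.max_context_length < export.max_seq_length.
    _validate_args(llm_config)

    # Quantizer selection now reads llm_config.quantization and llm_config.backend.
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)
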
@@ -1033,28 +1021,30 @@ def _to_edge_and_lower_llama( # noqa: C901 return builder -def _export_llama(args) -> LLMEdgeManager: # noqa: C901 - _validate_args(args) +def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 + _validate_args(llm_config) - pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) additional_passes = [] - if args.model in TORCHTUNE_DEFINED_MODELS: + if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS: additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] # export_to_edge - builder_exported = _prepare_for_llama_export(args).export() + builder_exported = _prepare_for_llama_export(llm_config).export() builder_exported.run_canonical_optimizations() modelname = builder_exported.modelname - if args.export_only: + if llm_config.export.export_only: exit() if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None: - # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False - args.xnnpack = True + # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False + llm_config.backend.xnnpack.enabled = True - if args.xnnpack: + if llm_config.backend.xnnpack.enabled: builder = _to_edge_and_lower_llama_xnnpack( builder_exported, modelname, @@ -1062,9 +1052,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - xnnpack_extended_ops=args.xnnpack_extended_ops, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) else: builder = _to_edge_and_lower_llama( @@ -1074,33 +1064,33 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype, - vulkan=args.vulkan, - mps=args.mps, - coreml=args.coreml, - qnn=args.qnn, - dtype_override=args.dtype_override, - enable_dynamic_shape=args.enable_dynamic_shape, - use_kv_cache=args.use_kv_cache, - embedding_quantize=args.embedding_quantize, - pt2e_quantize=args.pt2e_quantize, - coreml_ios=args.coreml_ios, - coreml_quantize=args.coreml_quantize, - coreml_compute_units=args.coreml_compute_units, - use_qnn_sha=args.use_qnn_sha, - num_sharding=args.num_sharding, - soc_model=args.soc_model, - generate_etrecord=args.generate_etrecord, - verbose=args.verbose, + vulkan=llm_config.backend.vulkan.enabled, + mps=llm_config.backend.mps.enabled, + coreml=llm_config.backend.coreml.enabled, + qnn=llm_config.backend.qnn.enabled, + dtype_override=llm_config.model.dtype_override, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + use_kv_cache=llm_config.model.use_kv_cache, + embedding_quantize=llm_config.quantization.embedding_quantize, + pt2e_quantize=llm_config.quantization.pt2e_quantize, + coreml_ios=llm_config.backend.coreml.ios, + coreml_quantize=llm_config.backend.coreml.quantize, + coreml_compute_units=llm_config.backend.coreml.compute_units, + use_qnn_sha=llm_config.backend.qnn.use_sha, + num_sharding=llm_config.backend.qnn.num_sharding, + soc_model=llm_config.backend.qnn.soc_model, + generate_etrecord=llm_config.debug.generate_etrecord, + verbose=llm_config.debug.verbose, ) - if args.profile_memory: + if llm_config.debug.profile_memory: generate_memory_trace(builder.export_program, "memory_profile.json") if builder.dtype == DType.fp16: modelname = 
f"{modelname}_h" - if args.output_name: - modelname = args.output_name + if llm_config.export.output_name: + modelname = llm_config.export.output_name if modelname.endswith(".pte"): output_file = modelname modelname = modelname[:-4] @@ -1150,31 +1140,7 @@ def _load_llama_model_metadata( return metadata -def _load_llama_model( - modelname: str = "llama3", - *, - checkpoint: Optional[str] = None, - checkpoint_dir: Optional[str] = None, - params_path: Optional[str] = None, - use_kv_cache: bool = False, - use_sdpa_with_kv_cache: bool = False, - generate_full_logits: bool = False, - weight_type: WeightType = WeightType.LLAMA, - enable_dynamic_shape: bool = False, - calibration_tasks: Optional[List[str]] = None, - calibration_limit: Optional[int] = None, - calibration_seq_length: Optional[int] = None, - calibration_data: Optional[str] = None, - tokenizer_path: Optional[str] = None, - verbose: bool = False, - max_seq_len: int = 128, - max_context_len: int = 128, - input_prune_map_path: Optional[str] = None, - output_prune_map_path: Optional[str] = None, - metadata_str: Optional[str] = None, - dtype_override: Optional[DType] = None, - args, -) -> "LLMEdgeManager": +def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": """ A helper util that builds a Llama2 model. It returns a LLMEdgeManager that can help further lower the model to ExecuTorch. @@ -1182,6 +1148,7 @@ def _load_llama_model( An instance of LLMEdgeManager which contains the eager mode model. """ + modelname = llm_config.base.model_class if modelname in EXECUTORCH_DEFINED_MODELS: module_name = "llama" model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. @@ -1194,53 +1161,40 @@ def _load_llama_model( else: raise ValueError(f"{modelname} is not a valid Llama model.") - torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None - model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( EagerModelFactory.create_model( module_name, model_class_name, - checkpoint=checkpoint, - checkpoint_dir=checkpoint_dir, - params=params_path, - use_kv_cache=use_kv_cache, - use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, - generate_full_logits=generate_full_logits, - fairseq2=weight_type == WeightType.FAIRSEQ2, - max_seq_len=max_seq_len, - max_context_len=max_context_len, - enable_dynamic_shape=enable_dynamic_shape, - input_prune_map_path=input_prune_map_path, - output_prune_map_path=output_prune_map_path, - dtype=torch_dtype, - args=args, + llm_config=llm_config, ) ) + # Convert dtype override string to actual type. 
+ dtype_override = DType[llm_config.model.dtype_override] return LLMEdgeManager( model=model, modelname=modelname, max_seq_len=model.max_seq_len, # type: ignore dtype=dtype_override, - use_kv_cache=use_kv_cache, - generate_full_logits=generate_full_logits, + use_kv_cache=llm_config.model.use_kv_cache, + generate_full_logits=llm_config.debug.generate_full_logits, example_inputs=example_inputs, example_kwarg_inputs=example_kwarg_inputs, dynamic_shapes=dynamic_shapes, - enable_dynamic_shape=enable_dynamic_shape, - calibration_tasks=calibration_tasks, - calibration_limit=calibration_limit, - calibration_seq_length=calibration_seq_length, - calibration_data=calibration_data, - tokenizer_path=tokenizer_path, - use_legacy_export=args.qnn, - save_exported_program=args.export_only, - verbose=verbose, + enable_dynamic_shape=llm_config.model.enable_dynamic_shape, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, + calibration_data=llm_config.quantization.calibration_data, + tokenizer_path=llm_config.base.tokenizer_path, + use_legacy_export=llm_config.backend.qnn.enabled, + save_exported_program=llm_config.export.export_only, + verbose=llm_config.debug.verbose, metadata=_load_llama_model_metadata( - weight_type, - use_kv_cache, - use_sdpa_with_kv_cache, - enable_dynamic_shape, + WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA, + llm_config.model.use_kv_cache, + llm_config.model.use_sdpa_with_kv_cache, + llm_config.model.enable_dynamic_shape, # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got # `Union[Tensor, Module]`. model.max_seq_len, @@ -1253,7 +1207,7 @@ def _load_llama_model( # pyre-fixme[6]: For 8th argument expected `int` but got `Union[Tensor, # Module]`. model.vocab_size, - metadata_str, + llm_config.base.metadata, ), ) @@ -1470,9 +1424,9 @@ def _get_source_transforms( # noqa return transforms -def get_llama_model(args): - _validate_args(args) - e_mgr = _prepare_for_llama_export(args) +def get_llama_model(llm_config: LlmConfig): + _validate_args(llm_config) + e_mgr = _prepare_for_llama_export(llm_config) model = ( e_mgr.model.eval().to(device="cuda") if torch.cuda.is_available() diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index d6400c29db8..ec9646be6f4 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -8,7 +8,7 @@ import json import os -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch from executorch.examples.models.checkpoint import ( @@ -16,6 +16,7 @@ get_default_model_resource_dir, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.llama_transformer import construct_transformer from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope @@ -36,26 +37,24 @@ def convert_to_llama_checkpoint(**kwargs): class Llama2Model(EagerModelBase): - def __init__(self, **kwargs): + def __init__(self, llm_config: Optional[LlmConfig] = None): resource_dir = get_default_model_resource_dir(__file__) - # Use single checkpoint file. - checkpoint_path = kwargs.get("checkpoint", None) - # Check if checkpoint_dir was provided for a sharded checkpoint. - checkpoint_dir = kwargs.get("checkpoint_dir", None) + self.llm_config = llm_config if llm_config else LlmConfig() - # Params file. 
- params_path = kwargs.get("params", None) + checkpoint_path = self.llm_config.base.checkpoint + checkpoint_dir = self.llm_config.base.checkpoint_dir + params_path = self.llm_config.base.params - self.use_kv_cache = kwargs.get("use_kv_cache", False) - self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False) - self.generate_full_logits = kwargs.get("generate_full_logits", False) - self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False) - self.input_prune_map_path = kwargs.get("input_prune_map_path", None) - self.output_prune_map_path = kwargs.get("output_prune_map_path", None) - self.max_seq_len = kwargs.get("max_seq_len", 128) - self.max_context_len = kwargs.get("max_context_len", 128) - self.args = kwargs.get("args", None) + self.use_kv_cache = self.llm_config.model.use_kv_cache + self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache + self.generate_full_logits = self.llm_config.debug.generate_full_logits + self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape + self.input_prune_map_path = self.llm_config.model.input_prune_map + self.output_prune_map_path = self.llm_config.model.output_prune_map + self.max_seq_len = self.llm_config.export.max_seq_length + self.max_context_len = self.llm_config.export.max_context_length + self.verbose = self.llm_config.debug.verbose assert ( self.max_context_len >= self.max_seq_len @@ -99,7 +98,7 @@ def __init__(self, **kwargs): checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True) # If given checkpoint is fairseq, convert to llama checkpoint. - fairseq2_checkpoint = kwargs.get("fairseq2", False) + fairseq2_checkpoint = self.llm_config.base.fairseq2 if fairseq2_checkpoint: print("Using fairseq2 checkpoint") checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint) @@ -158,13 +157,14 @@ def __init__(self, **kwargs): if model_args.use_scaled_rope: # Older models don't have use_scaled_rope configuration - assert self.args.model not in ["llama2", "stories110m"] + model_name = str(self.llm_config.base.model_class) + assert model_name not in ["llama2", "stories110m"] # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor - if self.args.model not in ["llama3", "llama3_1"]: + if model_name not in ["llama3", "llama3_1"]: model_args.rope_scale_factor = 32 - if kwargs.get("verbose", False): + if self.verbose: print("============= weights ================") print("{key} : {weights.numel()} : {weights.size()}") for key, weights in checkpoint.items(): @@ -196,7 +196,7 @@ def __init__(self, **kwargs): self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime( self.model_ ) - elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant: + elif self.llm_config.quantization.use_spin_quant: print("Using SPIN quantization.") self._transform_for_pre_quantization(checkpoint, model_args) @@ -205,11 +205,12 @@ def __init__(self, **kwargs): ) sanitize_checkpoint_from_pre_quantization(checkpoint) - elif hasattr(self.args, "use_qat") and self.args.use_qat: + elif self.llm_config.quantization.use_qat: print("Using QAT quantization.") self._transform_for_pre_quantization(checkpoint, model_args) - if hasattr(self.args, "use_lora") and self.args.use_lora: - assert model_args.lora_args["rank"] == self.args.use_lora + if self.llm_config.base.use_lora: + lora_rank = self.llm_config.base.use_lora + assert model_args.lora_args["rank"] == lora_rank from .source_transformation.lora import ( transform_linear_for_lora_after_quantization, ) @@ -217,7 +218,7 
@@ def __init__(self, **kwargs): self.model_ = transform_linear_for_lora_after_quantization( self.model_, checkpoint, - self.args.use_lora, + lora_rank, ) from .source_transformation.pre_quantization import ( @@ -226,16 +227,16 @@ def __init__(self, **kwargs): sanitize_checkpoint_from_pre_quantization(checkpoint) - if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink: + if self.llm_config.model.use_attention_sink: from .source_transformation.attention_sink import enable_attention_sink - attention_sink_params = self.args.use_attention_sink.split(",") + attention_sink_params = self.llm_config.model.use_attention_sink.split(",") assert len(attention_sink_params) == 3 sink_size = int(attention_sink_params[0]) window_size = int(attention_sink_params[1]) eviction_batch_size = int(attention_sink_params[2]) - assert self.args.max_context_length == sink_size + window_size + assert self.llm_config.export.max_context_length == sink_size + window_size self.model_ = enable_attention_sink( module=self.model_, @@ -278,7 +279,7 @@ def __init__(self, **kwargs): f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match." ) if unexpected: - if kwargs.get("verbose", False): + if self.verbose: print(f"Unexpected keys: {unexpected}") # Prune the input layer if input_prune_map is provided @@ -326,20 +327,22 @@ def get_example_inputs_kvcache_sdpa(self): ) def _transform_for_pre_quantization(self, checkpoint, model_args): - assert hasattr(self.args, "preq_mode"), "preq_mode must be specified" - assert self.args.preq_mode in [ + assert self.llm_config.base.preq_mode, "preq_mode must be specified" + assert self.llm_config.base.preq_mode in [ "8da4w", "8da4w_output_8da8w", - ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant." - assert hasattr( - self.args, "preq_group_size" - ), "preq_group_size must be specified" - assert hasattr(self.args, "dtype_override"), "dtype_override must be specified" + ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant." + assert self.llm_config.base.preq_group_size, "preq_group_size must be specified" + assert self.llm_config.model.dtype_override, "dtype_override must be specified" + from .source_transformation.pre_quantization import ( transform_linear_for_pre_quantization, ) - assert self.args.preq_group_size == model_args.quantization_args["group_size"] + assert ( + self.llm_config.base.preq_group_size + == model_args.quantization_args["group_size"] + ) mapping = { "fp32": torch.float32, @@ -348,7 +351,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): } # Transform the output layer first if needed. 
- if self.args.preq_mode == "8da4w_output_8da8w": + if self.llm_config.base.preq_mode == "8da4w_output_8da8w": from .source_transformation.pre_quantization import ( transform_output_linear_for_pre_quantization, ) @@ -356,20 +359,20 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_output_linear_for_pre_quantization( module=self.model_, checkpoint=checkpoint, - dtype=mapping[self.args.dtype_override], + dtype=mapping[self.llm_config.model.dtype_override], ) self.model_ = transform_linear_for_pre_quantization( self.model_, checkpoint, - self.args.preq_group_size, - mapping[self.args.dtype_override], + self.llm_config.base.preq_group_size, + mapping[self.llm_config.model.dtype_override], ) embedding_bit_width, embedding_group_size = None, None - if hasattr(self.args, "preq_embedding_quantize"): + if self.llm_config.base.preq_embedding_quantize: embedding_bit_width, embedding_group_size = ( - self.args.preq_embedding_quantize.split(",") + self.llm_config.base.preq_embedding_quantize.split(",") ) from .source_transformation.pre_quantization import ( transform_embedding_for_pre_quantization, @@ -387,7 +390,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args): self.model_ = transform_embedding_for_pre_quantization( self.model_, checkpoint, - mapping[self.args.dtype_override], + mapping[self.llm_config.model.dtype_override], int(embedding_bit_width), embedding_group_size, ) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 0b842a8f976..c55ad0eea28 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -10,6 +10,7 @@ import torch +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( _prepare_for_llama_export, build_args_parser as _build_args_parser, @@ -23,19 +24,24 @@ class EagerLlamaRunner(LlamaRunner): Runs llama in eager mode with provided checkpoint file. """ - def __init__(self, args): - with open(args.params, "r") as f: + def __init__( + self, + llm_config: LlmConfig, + tokenizer_config_path: Optional[str] = None, + use_attention_sink: bool = False, + ): + with open(llm_config.base.params, "r") as f: params = json.loads(f.read()) super().__init__( - tokenizer_path=args.tokenizer_path, - tokenizer_config_path=args.tokenizer_config_path, - max_seq_len=args.max_seq_length, + tokenizer_path=llm_config.base.tokenizer_path, + tokenizer_config_path=tokenizer_config_path, + max_seq_len=llm_config.export.max_seq_length, max_batch_size=1, - use_kv_cache=args.use_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, vocab_size=params["vocab_size"], device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) self.model = manager.model.eval().to(device=self.device) def forward( @@ -49,6 +55,7 @@ def forward( def build_args_parser() -> argparse.ArgumentParser: parser = _build_args_parser() + # Runner-specific arguments that aren't part of LlmConfig parser.add_argument( "--prompt", type=str, @@ -89,22 +96,41 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None: parser = build_args_parser() args = parser.parse_args() + # Convert args to LlmConfig for model configuration. + llm_config = LlmConfig.from_args(args) + + # Extract runner-specific parameters. 
+ prompt = args.prompt + temperature = args.temperature + show_tokens = args.show_tokens + chat_mode = args.chat + tokenizer_config_path = args.tokenizer_config_path + use_attention_sink = args.use_attention_sink + with torch.no_grad(): - runner = runner_class(args) # pyre-ignore: Missing argument [20] + # Create runner with LlmConfig and separate runner parameters. + runner = runner_class( + llm_config=llm_config, + tokenizer_config_path=tokenizer_config_path, + use_attention_sink=use_attention_sink, + ) + generated_tokens = ( runner.chat_completion( - max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length, - temperature=args.temperature, - show_progress=args.show_tokens, + max_seq_len=( + 1000000 if use_attention_sink else llm_config.export.max_seq_length + ), + temperature=temperature, + show_progress=show_tokens, ) - if args.chat + if chat_mode else runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + prompt=prompt, + temperature=temperature, echo=True, ) ) - if args.show_tokens: + if show_tokens: print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}") diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index b94adb5fa0c..f2ac9497604 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -7,6 +7,7 @@ import unittest from executorch.devtools.backend_debug import get_delegation_info +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( _export_llama, build_args_parser, @@ -40,7 +41,8 @@ def test_has_expected_ops_and_op_counts(self): args.use_kv_cache = True args.verbose = True - builder = _export_llama(args) + llm_config = LlmConfig.from_args(args) + builder = _export_llama(llm_config) graph_module = builder.edge_manager.exported_program().graph_module delegation_info = get_delegation_info(graph_module) diff --git a/examples/models/llama3_2_vision/runner/eager.py b/examples/models/llama3_2_vision/runner/eager.py index c5d91013077..5e68a43bf8e 100644 --- a/examples/models/llama3_2_vision/runner/eager.py +++ b/examples/models/llama3_2_vision/runner/eager.py @@ -8,6 +8,7 @@ from typing import Optional import torch +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export from executorch.examples.models.llama.runner.eager import execute_runner @@ -22,18 +23,23 @@ class EagerLlamaRunner(TorchTuneLlamaRunner): Runs llama in eager mode with provided checkpoint file. 
""" - def __init__(self, args): - with open(args.params, "r") as f: + def __init__( + self, + llm_config: LlmConfig, + tokenizer_config_path: Optional[str] = None, + use_attention_sink: bool = False, + ): + with open(llm_config.base.params, "r") as f: params = json.loads(f.read()) super().__init__( - tokenizer_path=args.tokenizer_path, - max_seq_len=args.max_seq_length, + tokenizer_path=llm_config.base.tokenizer_path, + max_seq_len=llm_config.export.max_seq_length, max_batch_size=1, - use_kv_cache=args.use_kv_cache, + use_kv_cache=llm_config.model.use_kv_cache, vocab_size=params["vocab_size"], device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export(args) + manager: LLMEdgeManager = _prepare_for_llama_export(llm_config) self.model = manager.model.eval().to(device=self.device) def forward( diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 18ef83ee1e4..32b3ff448ac 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -16,8 +16,8 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) +from executorch.examples.models.llama.config.llm_config import LlmConfig from executorch.examples.models.llama.export_llama_lib import ( - build_args_parser, get_quantizer_and_quant_params, ) from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( @@ -92,32 +92,26 @@ def forward(self, input_pos, embeddings): dynamic_shapes=dynamic_shapes, ) + # Manually set some LlmConfig options. + llm_config = LlmConfig() + llm_config.base.params = "params.json" + llm_config.backend.xnnpack.enabled = True + llm_config.quantization.qmode = "8da4w" + llm_config.quantization.group_size = 128 + llm_config.quantization.embedding_quantize = "4,32" + dtype_override = DType.fp32 - parser = build_args_parser() - args = parser.parse_args( - [ - "-p", - "params.json", - "-X", - "-qmode", - "8da4w", - "--group_size", - "128", - "--embedding-quantize", - "4,32", - ] - ) quant_transform = get_quant_weight_transform( - quantization_mode=args.quantization_mode, - group_size=args.group_size, + quantization_mode=llm_config.quantization.qmode, + group_size=llm_config.quantization.group_size, computation_dtype=dtype_override, - checkpoint_path=args.checkpoint, - tokenizer_path=args.tokenizer_path, - calibration_tasks=args.calibration_tasks, - calibration_limit=args.calibration_limit, - calibration_seq_length=args.calibration_seq_length, + checkpoint_path=llm_config.base.checkpoint, + tokenizer_path=llm_config.base.tokenizer_path, + calibration_tasks=llm_config.quantization.calibration_tasks, + calibration_limit=llm_config.quantization.calibration_limit, + calibration_seq_length=llm_config.quantization.calibration_seq_length, ) - _, quantizers, _ = get_quantizer_and_quant_params(args) + _, quantizers, _ = get_quantizer_and_quant_params(llm_config) source_transforms = [] if llava.use_sdpa_with_kv_cache_op: source_transforms.append(replace_kv_cache_with_custom_kv_cache) @@ -279,6 +273,20 @@ def get_tokenizer_for_llava_runner(llava_model): t.export("tokenizer.bin") +def create_llava_config_from_args(args): + """ + Create an LlmConfig from command line arguments for LLaVA export + """ + llm_config = LlmConfig() + + llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache + llm_config.export.max_seq_length = args.max_seq_len + llm_config.export.output_name = args.pte_name + llm_config.debug.profile_memory = args.profile_memory + + return llm_config + + def main(): 
parser = ArgumentParser() parser.add_argument( @@ -311,28 +319,33 @@ def main(): help="Generate chrome trace of activation memory for intermediate tensors.", ) args = parser.parse_args() + + # Create LlmConfig from args + llm_config = create_llava_config_from_args(args) + logging.info( - f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}" + f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {llm_config.model.use_sdpa_with_kv_cache}, max_seq_len: {llm_config.export.max_seq_length}" ) + llava_model = LlavaModel( - use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache, - max_seq_len=args.max_seq_len, + use_sdpa_with_kv_cache_op=llm_config.model.use_sdpa_with_kv_cache, + max_seq_len=llm_config.export.max_seq_length, ) executorch_program = export_all(llava_model) # memory profiling - if args.profile_memory: + if llm_config.debug.profile_memory: for method_name in executorch_program.methods: generate_memory_trace( executorch_program, - f"{args.pte_name}_{method_name}.json", + f"{llm_config.export.output_name}_{method_name}.json", method_name=method_name, ) - with open(args.pte_name, "wb") as f: + with open(llm_config.export.output_name, "wb") as f: executorch_program.write_to_file(f) - logging.info(f"Exported ExecuTorch program to {args.pte_name}") + logging.info(f"Exported ExecuTorch program to {llm_config.export.output_name}") # artifacts if args.with_artifacts: From c22236e0e3e66974a63bcc53365922576cb2ef65 Mon Sep 17 00:00:00 2001 From: Jack <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 10 Jun 2025 07:14:05 -0700 Subject: [PATCH 5/6] Fix merge conflict in export_llama_lib.py --- examples/models/llama/export_llama_lib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 6f3a82b5126..1f055d65822 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -52,7 +52,6 @@ get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace -from omegaconf.dictconfig import DictConfig from ..model_factory import EagerModelFactory from .source_transformation.apply_spin_quant_r1_r2 import ( From d0e4e4695714edba067348e3ee9b0d9df4d5ebf9 Mon Sep 17 00:00:00 2001 From: Jack <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 10 Jun 2025 07:14:50 -0700 Subject: [PATCH 6/6] Fix merge conflict in mps_example.py --- examples/apple/mps/scripts/mps_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 18e58617993..5ccbc987b4d 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -131,7 +131,7 @@ def parse_args(): args = parser.parse_args() return args - + if __name__ == "__main__": # noqa: C901 args = parse_args()
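
Taken together, the series moves the Llama export flow onto LlmConfig while keeping the existing CLI as a thin front end via LlmConfig.from_args. A minimal end-to-end sketch, illustrative only; the CLI flags are reused from the old llava invocation above and the output name is a placeholder:

    from executorch.examples.models.llama.config.llm_config import LlmConfig
    from executorch.examples.models.llama.export_llama_lib import (
        _export_llama,
        build_args_parser,
    )

    # Parse the existing CLI flags, then translate the Namespace into an LlmConfig.
    parser = build_args_parser()
    args = parser.parse_args(["-p", "params.json", "-X"])
    llm_config = LlmConfig.from_args(args)

    # Fields can still be overridden directly on the structured config.
    llm_config.export.output_name = "llama.pte"

    # The export path now consumes the LlmConfig end to end.
    builder = _export_llama(llm_config)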