# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

|  | 10 | +""" | 
|  | 11 | +Configurations for exporting Llama. | 
|  | 12 | +
 | 
|  | 13 | +Uses dataclases, which integrate with OmegaConf and Hydra. | 
|  | 14 | +""" | 
|  | 15 | + | 
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Literal, Optional


################################################################################
################################## BaseConfig ##################################
################################################################################


class ModelType(str, Enum):
    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"


class PreqMode(str, Enum):
    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"


@dataclass
class BaseConfig:
    """
    These fields are specific to the particular model, e.g. whether it is
    Qwen3 0.6B or Phi-4-mini. You can expect these fields to change from
    model to model.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that are applied during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"

|  | 69 | + | 
|  | 70 | +################################################################################ | 
|  | 71 | +################################# ModelConfig ################################## | 
|  | 72 | +################################################################################ | 
|  | 73 | + | 
|  | 74 | + | 
|  | 75 | +class DtypeOverride(str, Enum): | 
|  | 76 | +    FP32 = "fp32" | 
|  | 77 | +    FP16 = "fp16" | 
|  | 78 | +    BF16 = "bf16" | 
|  | 79 | + | 
|  | 80 | + | 
|  | 81 | +@dataclass | 
|  | 82 | +class ModelConfig: | 
|  | 83 | +    """ | 
|  | 84 | +    These are not necessarily specific to the model, but are needed to finish off | 
|  | 85 | +    the rest of the model configuration in eager. You can think of these like | 
|  | 86 | +    optimizations / actual configurations. The same ModelConfig can be applied | 
|  | 87 | +    to different models. | 
|  | 88 | +    """ | 

    dtype_override: DtypeOverride = DtypeOverride.FP32
    enable_dynamic_shape: bool = True
    use_shared_embedding: bool = False
    use_sdpa_with_kv_cache: bool = False
    expand_rope_table: bool = False
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None

    # Below are config options relating to the KV cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None

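# Illustrative example: the same ModelConfig can be reused across models, e.g.
# one that enables the KV cache and the SDPA-with-KV-cache option with a bf16
# dtype override.
#
#     model = ModelConfig(
#         dtype_override=DtypeOverride.BF16,
#         use_kv_cache=True,
#         use_sdpa_with_kv_cache=True,
#     )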

################################################################################
################################ ExportConfig ##################################
################################################################################


@dataclass
class ExportConfig:
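    """
    Options controlling the export flow itself, e.g. maximum sequence and
    context lengths and where the exported artifact is written.
    """
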
    max_seq_length: int = 128
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    so_library: Optional[str] = None
    export_only: bool = False


################################################################################
################################# DebugConfig ##################################
################################################################################


@dataclass
class DebugConfig:
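    """
    Options for debugging and profiling the export, e.g. memory profiling,
    ETRecord generation, and verbose logging.
    """
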
    profile_memory: bool = False
    profile_path: Optional[str] = None
    generate_etrecord: bool = False
    generate_full_logits: bool = False
    verbose: bool = False


################################################################################
############################# QuantizationConfig ###############################
################################################################################


class Pt2eQuantize(str, Enum):
    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"


class SpinQuant(str, Enum):
    CUDA = "cuda"
    NATIVE = "native"


@dataclass
class QuantizationConfig:
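    """
    Options controlling quantization: the qmode string, embedding quantization,
    PT2E quantization, SpinQuant, QAT, and calibration settings.
    """
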
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: Optional[bool] = None
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: Optional[str] = None

    def __post_init__(self):
        if self.qmode:
            self._validate_qmode()

    def _validate_qmode(self) -> None:
        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
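        # Strings such as "torchao:8da4w" or "torchao:fpa4w" satisfy these patterns.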

        if self.qmode in choices:
            return

        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
        )


################################################################################
############################### BackendConfig ##################################
################################################################################


@dataclass
class XNNPackConfig:
    enabled: bool = False
    extended_ops: bool = False


class CoreMLQuantize(str, Enum):
    B4W = "b4w"
    C4W = "c4w"


class CoreMLComputeUnit(str, Enum):
    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"


@dataclass
class CoreMLConfig:
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: Literal[15, 16, 17, 18] = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")


@dataclass
class VulkanConfig:
    enabled: bool = False


@dataclass
class QNNConfig:
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
    use_qnn_sha: bool = False
    optimized_rotation_path: Optional[str] = None
    num_sharding: int = 0


@dataclass
class MPSConfig:
    enabled: Optional[bool] = False


@dataclass
class BackendConfig:
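    """
    Per-backend delegation options (XNNPACK, CoreML, Vulkan, QNN, MPS).
    """
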
    xnnpack: XNNPackConfig = field(default_factory=XNNPackConfig)
    coreml: CoreMLConfig = field(default_factory=CoreMLConfig)
    vulkan: VulkanConfig = field(default_factory=VulkanConfig)
    qnn: QNNConfig = field(default_factory=QNNConfig)
    mps: MPSConfig = field(default_factory=MPSConfig)


################################################################################
################################## LlmConfig ###################################
################################################################################


@dataclass
class LlmConfig:
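    """
    Top-level configuration for exporting an LLM, composing the individual
    config sections into a single object.
    """
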
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
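

# Illustrative sketch (not part of this module's API): since these dataclasses
# are designed to integrate with OmegaConf, an LlmConfig can be built from the
# defaults above and then merged with dotlist-style overrides. The override
# values below are hypothetical examples.
#
#     from omegaconf import OmegaConf
#
#     default_cfg = OmegaConf.structured(LlmConfig)
#     overrides = OmegaConf.from_dotlist(
#         ["model.use_kv_cache=true", "backend.xnnpack.enabled=true"]
#     )
#     cfg = OmegaConf.merge(default_cfg, overrides)
#     print(OmegaConf.to_yaml(cfg))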