99"""
1010Configurations for exporting Llama.
1111
12- Uses dataclases , which integrate with OmegaConf and Hydra.
12+ Uses dataclasses , which integrate with OmegaConf and Hydra.
1313"""
1414
15+ import argparse
16+ import ast
1517import re
1618from dataclasses import dataclass , field
1719from enum import Enum
18- from typing import List , Optional
20+ from typing import ClassVar , List , Optional , Self
1921
2022
2123################################################################################
@@ -44,7 +46,7 @@ class PreqMode(str, Enum):
     If you are dealing with pre-quantized checkpoints, this used to
     be the way to specify them. Now you don't need to specify these
     options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preservce backward compatibility.
+    are still around to preserve backward compatibility.
     """

     PREQ_8DA4W = "8da4w"
@@ -57,18 +59,35 @@ class BaseConfig:
     Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
     and are the minimal set of parameters needed to load the pretrained
     eager model and its weights.
+
+    Attributes:
+        model_class: Which model to export.
+        params: Model parameters, such as n_layers, hidden_size, etc.
+            If left empty, defaults from model_args.py are used.
+        checkpoint: Path to the checkpoint file.
+            If left empty, the model is initialized with random weights.
+        checkpoint_dir: Path to a directory containing sharded checkpoint files.
+        tokenizer_path: Path to the tokenizer file.
+        metadata: JSON string containing metadata information,
+            e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'.
+        use_lora: Rank of the LoRA adapter; 0 means LoRA is not used. For use with QAT.
+        fairseq2: For legacy internal use cases; safe to ignore.
+        preq_mode: Legacy option to specify how prequantized weights are loaded.
+            Going forward, ExecuTorch supports loading weights prequantized through
+            TorchAo as-is, without any special handling.
+        preq_group_size: Legacy option to specify the group size of prequantized weights.
+        preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+            are loaded.
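+
+    Example (an illustrative sketch; the paths and values below are hypothetical,
+    not defaults shipped with the repo):
+
+        base = BaseConfig(
+            model_class=ModelType.LLAMA3,
+            checkpoint="/path/to/consolidated.00.pth",
+            tokenizer_path="/path/to/tokenizer.model",
+            metadata='{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
+        )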
6081 """
6182
6283 model_class : ModelType = ModelType .LLAMA3
6384 params : Optional [str ] = None
6485 checkpoint : Optional [str ] = None
65- checkpoint_dir : Optional [str ] = None # For sharded checkpoint.
86+ checkpoint_dir : Optional [str ] = None
6687 tokenizer_path : Optional [str ] = None
6788 metadata : Optional [str ] = None
-    use_lora: bool = False
-    fairseq2: bool = False  # For legacy internal use cases.
-
-    # Legacy pre-quantization options that happen during model weight loading.
+    use_lora: int = 0
+    fairseq2: bool = False
     preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
@@ -98,6 +117,32 @@ class ModelConfig:
     finish off the rest of the model configuration in eager. You can think
     of these like optimizations / actual configurations. The same ModelConfig
     can be applied to multiple models.
+
+    Attributes:
+        dtype_override: dtype to cast the model to.
+        enable_dynamic_shape: Whether to enable dynamic shapes on the sequence
+            length so that the model can handle arbitrary prefill lengths and
+            token generation.
+        use_shared_embedding: Whether the embedding/output weights should be
+            shared. Only available with torchao kernels, e.g. when
+            qmode is set to use a "torchao:8da(\\d+)w" pattern.
+        use_sdpa_with_kv_cache: Whether to use flash attention by substituting
+            for our custom SDPA op. Note that the naming is poor and this
+            doesn't actually have anything to do with the kv_cache at the moment.
+        expand_rope_table: Temporary workaround to expand the sin/cos table in the
+            head dim to take the vectorized path in optimized kernels.
+        use_attention_sink: Whether to use attention sink to support multi-round
+            conversation. Structured as
+            '<sink_size>,<window_size>,<batch_eviction_size>',
+            e.g. '4,2044,1024'.
+        output_prune_map: Path to the output pruning token mapping file (token_map.json).
+        input_prune_map: Path to the input pruning token mapping file (token_map.json).
+        use_kv_cache: Whether to use the KV cache.
+        quantize_kv_cache: Whether to perform int8 per-token quantization on the KV cache.
+        local_global_attention: List of integers specifying the local/global attention
+            pattern, e.g. [0, 16, 0, 16] specifies that every other layer is a sliding
+            window of 16; [0, 16, 32] specifies that the 2nd and 3rd layers have sliding
+            windows of 16 and 32; [16] specifies that all layers have a sliding window of 16.
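+
+    Example (an illustrative sketch; the values are hypothetical, not
+    recommended settings):
+
+        model_cfg = ModelConfig(
+            dtype_override=DtypeOverride.FP32,
+            use_kv_cache=True,
+            local_global_attention=[0, 16, 0, 16],
+        )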
101146 """
102147
103148 dtype_override : DtypeOverride = DtypeOverride .FP32
@@ -108,12 +153,44 @@ class ModelConfig:
     use_attention_sink: Optional[str] = None
     output_prune_map: Optional[str] = None
     input_prune_map: Optional[str] = None
-
-    # Below are config options relating to kv cache.
     use_kv_cache: bool = False
     quantize_kv_cache: bool = False
     local_global_attention: Optional[List[int]] = None

+    def __post_init__(self):
+        self._validate_attention_sink()
+        self._validate_local_global_attention()
+
+        if self.quantize_kv_cache and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+            )
+
+        if self.local_global_attention and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+            )
+
+    def _validate_attention_sink(self):
+        if self.use_attention_sink:
+            attention_sink_params = self.use_attention_sink.split(",")
+            if len(attention_sink_params) != 3:
+                raise ValueError(
+                    "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                )
+
+    def _validate_local_global_attention(self):
+        if self.local_global_attention:
+            local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+            # The value may arrive as an actual list or as its string form (e.g. from the CLI).
+            parsed = self.local_global_attention
+            if isinstance(parsed, str):
+                try:
+                    parsed = ast.literal_eval(parsed)
+                except (ValueError, SyntaxError):
+                    raise ValueError(local_global_err)
+            if not (
+                isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+            ):
+                raise ValueError(local_global_err)
+

 ################################################################################
 ################################ ExportConfig ##################################
@@ -124,6 +201,15 @@ class ModelConfig:
 class ExportConfig:
     """
     Configures properties relevant to the export process.
+
+    Attributes:
+        max_seq_length: Maximum length of the sequence to evaluate.
+        max_context_length: Maximum length of context for the model to remember.
+        output_dir: Output directory to save the exported .pte file to.
+        output_name: File name override for the exported .pte file.
+        so_library: Shared library providing custom quantized operators.
+        export_only: Whether to stop right after torch.export() and
+            just save the exported .pt2 graph file.
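+
+    Example (an illustrative sketch with hypothetical values):
+
+        export_cfg = ExportConfig(
+            max_seq_length=256,
+            max_context_length=1024,
+            output_name="llama3.pte",
+        )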
127213 """
128214
129215 max_seq_length : int = 128
@@ -133,6 +219,12 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False

+    def __post_init__(self):
+        if self.max_context_length < self.max_seq_length:
+            raise ValueError(
+                f"max_context_length of {self.max_context_length} cannot be shorter than max_seq_length of {self.max_seq_length}"
+            )
+


 ################################################################################
 ################################# DebugConfig ##################################
@@ -143,6 +235,16 @@ class ExportConfig:
 class DebugConfig:
     """
     Configures options to debug the export process.
+
+    Attributes:
+        profile_memory: Whether to generate a Chrome trace of activation memory
+            for intermediate tensors.
+        profile_path: Use cProfile to profile the export. Results are saved to
+            profile_path as an HTML file.
+        generate_etrecord: Whether to generate an ETRecord debug artifact.
+        generate_full_logits: Whether to keep the full logits, potentially useful
+            for debugging purposes. Kept off by default to save memory.
+        verbose: Whether to log the export process verbosely (log level >= INFO).
     """

     profile_memory: bool = False
@@ -188,8 +290,32 @@ class SpinQuant(str, Enum):
 class QuantizationConfig:
     """
     Configures how the model should be quantized (PTQ).
+
+    Attributes:
+        qmode: Quantization mode using TorchAo, expressed as a string.
+            See the __post_init__ validation for the available qmode options.
+        embedding_quantize: Type of embedding quantization.
+            Must be of the format '<bitwidth>,<groupsize>', e.g. '8,1024'.
+        pt2e_quantize: Quantization mode using pt2e, which is an alternative
+            to TorchAo that uses backend-aware graph mode quantization rather
+            than source transformation quantization.
+        group_size: Group size for quantization.
+        use_spin_quant: Which SpinQuant mode to use. If unspecified, SpinQuant
+            is not used.
+        use_qat: Whether the checkpoint was produced with quantization-aware
+            training (QAT).
+        calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+        calibration_limit: Number of samples used for calibration from lm_eval.
+        calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+        calibration_data: Prompts used for calibration.
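+
+    Example (an illustrative sketch; the values are hypothetical, not
+    recommended settings):
+
+        # Plain TorchAo 8da4w with 4-bit, group-size-32 embedding quantization.
+        quant_cfg = QuantizationConfig(qmode="8da4w", embedding_quantize="4,32")
+
+        # Low-bit ARM kernels via the torchao pattern (cannot be combined with
+        # XNNPack delegation, see LlmConfig._validate_low_bit).
+        quant_cfg = QuantizationConfig(qmode="torchao:8da4w", group_size=32)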
191310 """
192311
312+ # Constants.
313+ QMODE_OPTIONS : ClassVar [List [str ]] = ["int8" , "8da4w" , "8da4w-gptq" , "vulkan_4w" ]
314+ AO_QUANT_PATTERNS : ClassVar [List [str ]] = [
315+ r"torchao:8da(\d+)w" ,
316+ r"torchao:fpa(\d+)w" ,
317+ ]
318+
193319 qmode : Optional [str ] = None
194320 embedding_quantize : Optional [str ] = None
195321 pt2e_quantize : Optional [Pt2eQuantize ] = None
@@ -206,21 +332,29 @@ def __post_init__(self):
         self._validate_qmode()

     def _validate_qmode(self) -> None:
-        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
+        if not self.qmode:
+            return

-        if self.qmode in choices:
+        if self.qmode in self.QMODE_OPTIONS:
             return

-        for pattern in patterns:
+        # If qmode matches one of the patterns below, we are using
+        # ARM-based torchao ops.
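+        # For example, "torchao:8da4w" matches r"torchao:8da(\d+)w".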
+        for pattern in self.AO_QUANT_PATTERNS:
             matches = re.findall(pattern, self.qmode)
             if len(matches) == 1:
                 return

         raise ValueError(
-            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+            f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
         )

+    def _validate_embedding_quantize(self):
+        if not self.embedding_quantize:
+            return
+        if len(self.embedding_quantize.split(",")) != 2:
+            raise ValueError(
+                f'embedding_quantize of {self.embedding_quantize} must follow the format: "<bitwidth>,<groupsize>"'
+            )
+

 ################################################################################
 ############################### BackendConfig ##################################
@@ -229,6 +363,14 @@ def _validate_qmode(self) -> None:

 @dataclass
 class XNNPackConfig:
+    """
+    Configures the XNNPack backend.
+
+    Attributes:
+        enabled: Whether the XNNPack backend is enabled.
+        extended_ops: Whether to match more types of ops to delegate to XNNPack.
+    """
+
     enabled: bool = False
     extended_ops: bool = False

@@ -247,6 +389,10 @@ class CoreMLComputeUnit(str, Enum):

 @dataclass
 class CoreMLConfig:
+    """
+    Configures the CoreML backend.
+    """
+
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
@@ -261,11 +407,19 @@ def __post_init__(self):

 @dataclass
 class VulkanConfig:
+    """
+    Configures the Vulkan backend.
+    """
+
     enabled: bool = False


 @dataclass
 class QNNConfig:
+    """
+    Configures the QNN backend.
+    """
+
     enabled: bool = False
     use_sha: bool = False
     soc_model: str = "SM8650"
@@ -276,6 +430,10 @@ class QNNConfig:

 @dataclass
 class MPSConfig:
+    """
+    Configures the MPS backend.
+    """
+
     enabled: bool = False


@@ -310,3 +468,41 @@ class LlmConfig:
     debug: DebugConfig = field(default_factory=DebugConfig)
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
+
+    @staticmethod
+    def from_args(args: argparse.Namespace) -> Self:
+        """
+        For legacy use cases, converts CLI args from argparse into an LlmConfig,
+        which is used by the LLM export process.
+        """
+        llm_config = LlmConfig()
+
+        # TODO: conversion code.
+
+        return llm_config
+
+    def __post_init__(self):
+        self._validate_low_bit()
+
+    def _validate_low_bit(self):
+        if not self.quantization.qmode:
+            return
+
+        using_lowbit_ops = False
+        for pattern in self.quantization.AO_QUANT_PATTERNS:
+            matches = re.findall(pattern, self.quantization.qmode)
+            if len(matches) == 1:
+                using_lowbit_ops = True
+
+        # If we are using torchao's low-bit quantization kernels for ARM,
+        # we do not want to also be delegating to a CPU backend (XNNPack).
+        if using_lowbit_ops and self.backend.xnnpack.enabled:
+            raise ValueError(
+                "Cannot use low-bit torchao ops (from qmode=torchao:...) while also delegating to XNNPack."
+            )
+
+        # Shared embeddings are only supported with the low-bit kernels.
+        if self.model.use_shared_embedding and not using_lowbit_ops:
+            raise ValueError(
+                "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+            )
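+
+
+# Example (an illustrative sketch, not part of the module's API): composing an
+# LlmConfig manually instead of via from_args. The values are hypothetical, and
+# only sub-configs shown above are used.
+#
+#     llm_config = LlmConfig(
+#         model=ModelConfig(use_kv_cache=True),
+#         quantization=QuantizationConfig(qmode="torchao:8da4w"),
+#     )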