99"""
1010Configurations for exporting Llama.
1111
12- Uses dataclases , which integrate with OmegaConf and Hydra.
12+ Uses dataclasses , which integrate with OmegaConf and Hydra.
1313"""
1414
15+ import ast
1516import re
1617from dataclasses import dataclass , field
1718from enum import Enum
18- from typing import List , Optional
19+ from typing import ClassVar , List , Optional
1920
2021
2122################################################################################
@@ -44,7 +45,7 @@ class PreqMode(str, Enum):
    If you are dealing with pre-quantized checkpoints, this used to
    be the way to specify them. Now you don't need to specify these
    options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preservce backward compatibility.
+    are still around to preserve backward compatibility.
    """

    PREQ_8DA4W = "8da4w"
@@ -57,18 +58,35 @@ class BaseConfig:
    Configurations specific to the model, e.g. whether it’s Qwen3 or Phi-4-mini,
    and are the minimal set of parameters needed to load the pretrained
    eager model and its weights.
+
+    Attributes:
+        model_class: Which model to export.
+        params: Model parameters, such as n_layers, hidden_size, etc.
+            If left empty, will use defaults specified in model_args.py.
+        checkpoint: Path to the checkpoint file.
+            If left empty, the model will be initialized with random weights.
+        checkpoint_dir: Path to a directory containing sharded checkpoint files.
+        tokenizer_path: Path to the tokenizer file.
+        metadata: JSON string containing metadata information,
+            e.g. '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'.
+        use_lora: Rank of the LoRA adapter; if set to 0, LoRA is not used. For use with QAT.
+        fairseq2: For legacy internal use cases; safe to ignore.
+        preq_mode: Legacy option to specify how prequantized weights are loaded.
+            Going forward, ExecuTorch supports loading weights prequantized through
+            TorchAo as-is, without any special handling.
+        preq_group_size: Legacy option to specify the group size of prequantized weights.
+        preq_embedding_quantize: Legacy option to specify how prequantized embeddings
+            are loaded.
    """

    model_class: ModelType = ModelType.LLAMA3
    params: Optional[str] = None
    checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
+    checkpoint_dir: Optional[str] = None
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
-    use_lora: bool = False
-    fairseq2: bool = False  # For legacy internal use cases.
-
-    # Legacy pre-quantization options that happen during model weight loading.
+    use_lora: int = 0
+    fairseq2: bool = False
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"
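
For illustration only (not part of this commit), a BaseConfig for a Llama 3 checkpoint might be built like this; the paths are placeholders:

base_config = BaseConfig(
    model_class=ModelType.LLAMA3,
    checkpoint="/path/to/consolidated.00.pth",  # placeholder path
    params="/path/to/params.json",  # placeholder path
    tokenizer_path="/path/to/tokenizer.model",  # placeholder path
    metadata='{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}',
)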
@@ -98,6 +116,32 @@ class ModelConfig:
    finish off the rest of the model configuration in eager. You can think
    of these like optimizations / actual configurations. The same ModelConfig
    can be applied to multiple models.
+
+    Attributes:
+        dtype_override: dtype to cast the model to.
+        enable_dynamic_shape: whether to enable dynamic shapes on the sequence
+            length so that the model can handle arbitrary prefill lengths and
+            token generation.
+        use_shared_embedding: whether the embedding/output weights should be
+            shared. Only available with torchao kernels, e.g. when qmode is
+            set to a "torchao:8da(\\d+)w" pattern.
+        use_sdpa_with_kv_cache: Whether to use flash attention by swapping in
+            our custom SDPA op. Note that the naming is poor and this
+            doesn't actually have anything to do with the KV cache at the moment.
+        expand_rope_table: Temporary workaround to expand the sin/cos table in the
+            head dim to take the vectorized path in optimized kernels.
+        use_attention_sink: Whether to use attention sink to support multi-round
+            conversation. Structured as
+            '<sink_size>,<window_size>,<batch_eviction_size>',
+            e.g. '4,2044,1024'.
+        output_prune_map: Path to the output pruning token mapping file (token_map.json).
+        input_prune_map: Path to the input pruning token mapping file (token_map.json).
+        use_kv_cache: Whether to use the KV cache.
+        quantize_kv_cache: Whether to perform int8 per-token quantization on the KV cache.
+        local_global_attention: List of integers specifying the local/global attention
+            pattern, e.g. [0, 16, 0, 16] specifies that every other layer is a sliding
+            window of 16; [0, 16, 32] specifies that the 2nd and 3rd layers have sliding
+            windows of 16 and 32; [16] specifies that all layers have a sliding window of 16.
    """

    dtype_override: DtypeOverride = DtypeOverride.FP32
@@ -108,12 +152,44 @@ class ModelConfig:
    use_attention_sink: Optional[str] = None
    output_prune_map: Optional[str] = None
    input_prune_map: Optional[str] = None
-
-    # Below are config options relating to kv cache.
    use_kv_cache: bool = False
    quantize_kv_cache: bool = False
    local_global_attention: Optional[List[int]] = None

+    def __post_init__(self):
+        self._validate_attention_sink()
+        self._validate_local_global_attention()
+
+        if self.quantize_kv_cache and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot quantize the KV cache (quantize_kv_cache) without enabling the KV cache (use_kv_cache)"
+            )
+
+        if self.local_global_attention and not self.use_kv_cache:
+            raise ValueError(
+                "Cannot use local_global_attention without enabling the KV cache (use_kv_cache)"
+            )
+
+    def _validate_attention_sink(self):
+        if self.use_attention_sink:
+            attention_sink_params = self.use_attention_sink.split(",")
+            if len(attention_sink_params) != 3:
+                raise ValueError(
+                    "The value of use_attention_sink must be structured like '<sink_size>,<window_size>,<batch_eviction_size>'"
+                )
+
+    def _validate_local_global_attention(self):
+        if self.local_global_attention:
+            local_global_err = "The value of local_global_attention must be a list of integers, e.g., [0, 16, 0, 16]"
+            # The value may arrive either as an already-parsed list or as its
+            # string form (e.g. "[0, 16, 0, 16]" from the CLI).
+            parsed = self.local_global_attention
+            try:
+                if isinstance(parsed, str):
+                    parsed = ast.literal_eval(parsed)
+                if not (
+                    isinstance(parsed, list) and all(isinstance(i, int) for i in parsed)
+                ):
+                    raise ValueError(local_global_err)
+            except (ValueError, SyntaxError):
+                raise ValueError(local_global_err)
+

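For illustration only (not part of this commit), the validations added above behave roughly as follows:

# Passes validation: KV cache enabled, every other layer a sliding window of 16.
ModelConfig(use_kv_cache=True, local_global_attention=[0, 16, 0, 16])

# Passes validation: attention sink with sink_size=4, window_size=2044, eviction_batch_size=1024.
ModelConfig(use_kv_cache=True, use_attention_sink="4,2044,1024")

# Rejected in __post_init__: quantizing a KV cache that is not enabled.
try:
    ModelConfig(quantize_kv_cache=True)
except ValueError as e:
    print(e)
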
################################################################################
################################ ExportConfig ##################################
@@ -124,6 +200,15 @@ class ModelConfig:
class ExportConfig:
    """
    Configures properties relevant to the export process.
+
+    Attributes:
+        max_seq_length: Maximum length of the sequence to evaluate.
+        max_context_length: Maximum length of context for the model to remember.
+        output_dir: Output directory to save the exported .pte file to.
+        output_name: File name override for the exported .pte file.
+        so_library: Shared library providing custom quantized operators.
+        export_only: Whether to stop right after torch.export() and
+            just save the exported .pt2 graph file.
    """

    max_seq_length: int = 128
@@ -133,6 +218,12 @@ class ExportConfig:
    so_library: Optional[str] = None
    export_only: bool = False

+    def __post_init__(self):
+        if self.max_context_length < self.max_seq_length:
+            raise ValueError(
+                f"max_context_length of {self.max_context_length} cannot be shorter than max_seq_length of {self.max_seq_length}"
+            )
+

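A quick illustration of the constraint above (values are arbitrary, not from the commit):

ExportConfig(max_seq_length=128, max_context_length=2048)  # valid: the context covers the sequence

try:
    ExportConfig(max_seq_length=1024, max_context_length=256)
except ValueError:
    pass  # rejected: the model cannot attend to a sequence longer than its context
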
################################################################################
################################# DebugConfig ##################################
@@ -143,6 +234,16 @@ class ExportConfig:
class DebugConfig:
    """
    Configures options to debug the export process.
+
+    Attributes:
+        profile_memory: Whether to generate a Chrome trace of activation memory
+            for intermediate tensors.
+        profile_path: Use cProfile to profile the export. Results are saved to
+            profile_path as an HTML file.
+        generate_etrecord: Whether to generate an ETRecord debug artifact.
+        generate_full_logits: Whether to keep the full logits, potentially useful
+            for debugging purposes. Kept off by default to save memory.
+        verbose: Whether to log the export process verbosely (log level >= INFO).
    """

    profile_memory: bool = False
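
As a small, purely illustrative example (not part of this commit), a debug-heavy export run might set:

debug_config = DebugConfig(profile_memory=True, generate_etrecord=True, verbose=True)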
@@ -188,8 +289,32 @@ class SpinQuant(str, Enum):
class QuantizationConfig:
    """
    Configures how the model should be quantized (PTQ).
+
+    Attributes:
+        qmode: Quantization mode using TorchAo, expressed as a string.
+            See the __post_init__ validation for the available qmode options.
+        embedding_quantize: Type of embedding quantization.
+            Must be of the format '<bitwidth>,<groupsize>', e.g. '8,1024'.
+        pt2e_quantize: Quantization mode using pt2e, which is an alternative
+            to TorchAo that uses backend-aware graph mode quantization rather
+            than source transformation quantization.
+        group_size: Group size for quantization.
+        use_spin_quant: Which SpinQuant mode to use. If unspecified, SpinQuant
+            is not used.
+        use_qat: Whether the checkpoint was produced with quantization-aware training (QAT).
+        calibration_tasks: Tasks for GPTQ calibration from lm_eval.
+        calibration_limit: Number of samples used for calibration from lm_eval.
+        calibration_seq_length: Sequence length for GPTQ calibration from lm_eval.
+        calibration_data: Prompts used for calibration.
    """

+    # Constants.
+    QMODE_OPTIONS: ClassVar[List[str]] = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
+    AO_QUANT_PATTERNS: ClassVar[List[str]] = [
+        r"torchao:8da(\d+)w",
+        r"torchao:fpa(\d+)w",
+    ]
+
    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
@@ -206,21 +331,29 @@ def __post_init__(self):
        self._validate_qmode()

    def _validate_qmode(self) -> None:
-        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
-        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]
+        if not self.qmode:
+            return

-        if self.qmode in choices:
+        if self.qmode in self.QMODE_OPTIONS:
            return

-        for pattern in patterns:
+        # If qmode matches one of the patterns below, this means that we
+        # are using ARM-based torchao ops.
+        for pattern in self.AO_QUANT_PATTERNS:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
-            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
+            f"Got qmode {self.qmode}, but expected one of {self.QMODE_OPTIONS}, or one of the regex patterns {self.AO_QUANT_PATTERNS}."
        )

+    def _validate_embedding_quantize(self):
+        if len(self.embedding_quantize.split(",")) != 2:
+            raise ValueError(
+                f'embedding_quantize of {self.embedding_quantize} must follow the following format: "<bitwidth>,<groupsize>"'
+            )
+

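For illustration only (not part of this commit), the qmode validation accepts either one of the fixed options or a torchao pattern:

QuantizationConfig(qmode="8da4w")  # one of QMODE_OPTIONS
QuantizationConfig(qmode="torchao:8da4w")  # matches r"torchao:8da(\d+)w"

try:
    QuantizationConfig(qmode="int4")  # neither a known option nor a pattern
except ValueError:
    pass
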
################################################################################
############################### BackendConfig ##################################
@@ -229,6 +362,14 @@ def _validate_qmode(self) -> None:

@dataclass
class XNNPackConfig:
+    """
+    Configures the XNNPack backend.
+
+    Attributes:
+        enabled: Whether the XNNPack backend is enabled.
+        extended_ops: Whether to match more types of ops to delegate to XNNPack.
+    """
+
    enabled: bool = False
    extended_ops: bool = False

@@ -247,6 +388,10 @@ class CoreMLComputeUnit(str, Enum):

@dataclass
class CoreMLConfig:
+    """
+    Configures the CoreML backend.
+    """
+
    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
@@ -261,11 +406,19 @@ def __post_init__(self):

@dataclass
class VulkanConfig:
+    """
+    Configures the Vulkan backend.
+    """
+
    enabled: bool = False


@dataclass
class QNNConfig:
+    """
+    Configures the QNN backend.
+    """
+
    enabled: bool = False
    use_sha: bool = False
    soc_model: str = "SM8650"
@@ -276,6 +429,10 @@ class QNNConfig:

@dataclass
class MPSConfig:
+    """
+    Configures the MPS backend.
+    """
+
    enabled: bool = False


@@ -310,3 +467,29 @@ class LlmConfig:
    debug: DebugConfig = field(default_factory=DebugConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    backend: BackendConfig = field(default_factory=BackendConfig)
+
+    def __post_init__(self):
+        self._validate_low_bit()
+
+    def _validate_low_bit(self):
+        if not self.quantization.qmode:
+            return
+
+        using_lowbit_ops = False
+        for pattern in self.quantization.AO_QUANT_PATTERNS:
+            matches = re.findall(pattern, self.quantization.qmode)
+            if len(matches) == 1:
+                using_lowbit_ops = True
+
+        # If we are using torchao's low-bit quantization kernels for ARM,
+        # we do not want to also be delegating to a CPU backend (XNNPack).
+        if using_lowbit_ops and self.backend.xnnpack.enabled:
+            raise ValueError(
+                "Cannot use low-bit torchao ops (from qmode=torchao:...) while also delegating to XNNPack."
+            )
+
+        # Also, we can only use shared embeddings if we are using low-bit kernels.
+        if self.model.use_shared_embedding and not using_lowbit_ops:
+            raise ValueError(
+                "Can only use shared embeddings with low-bit ops (with qmode=torchao:...)."
+            )
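
For reference only (not part of this commit), and assuming LlmConfig also exposes base, model, and export sub-configs alongside the quantization, backend, and debug fields shown above (the base and export field names are assumptions), a full configuration might be assembled like this:

llm_config = LlmConfig(
    base=BaseConfig(model_class=ModelType.LLAMA3),  # assumed field name
    model=ModelConfig(use_kv_cache=True, use_sdpa_with_kv_cache=True),
    export=ExportConfig(max_seq_length=128, max_context_length=2048),  # assumed field name
    quantization=QuantizationConfig(qmode="8da4w"),
    backend=BackendConfig(xnnpack=XNNPackConfig(enabled=True)),
)
# "8da4w" is not a torchao:... pattern, so delegating to XNNPack is allowed here.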