 import ast
 import re
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import ClassVar, List, Optional
 
 
 ################################################################################
 
 
-MODEL_TYPE_OPTIONS = [
-    "stories110m",
-    "llama2",
-    "llama3",
-    "llama3_1",
-    "llama3_2",
-    "llama3_2_vision",
-    "static_llama",
-    "qwen2_5",
-    "qwen3-0_6b",
-    "qwen3-1_7b",
-    "qwen3-4b",
-    "phi_4_mini",
-    "smollm2",
-]
+class ModelType(str, Enum):
+    stories110m = "stories110m"
+    llama2 = "llama2"
+    llama3 = "llama3"
+    llama3_1 = "llama3_1"
+    llama3_2 = "llama3_2"
+    llama3_2_vision = "llama3_2_vision"
+    static_llama = "static_llama"
+    qwen2_5 = "qwen2_5"
+    qwen3_0_6b = "qwen3-0_6b"
+    qwen3_1_7b = "qwen3-1_7b"
+    qwen3_4b = "qwen3-4b"
+    phi_4_mini = "phi_4_mini"
+    smollm2 = "smollm2"
 
 
-PREQ_MODE_OPTIONS = [
-    "8da4w",
-    "8da4w_output_8da8w",
-]
+class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """
+
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"
 
 
 @dataclass
@@ -76,36 +82,34 @@ class BaseConfig:
     are loaded.
     """
 
-    model_class: str = "llama3"
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[str] = None
+    preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"
 
-    def __post_init__(self):
-        if self.model_class not in MODEL_TYPE_OPTIONS:
-            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
-
-        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
-            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
-
 
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################
 
 
-DTYPE_OVERRIDE_OPTIONS = [
-    "fp32",
-    "fp16",
-    "bf16",
-]
+class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"
 
 
 @dataclass
@@ -143,7 +147,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """
 
-    dtype_override: str = "fp32"
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -156,9 +160,6 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None
 
     def __post_init__(self):
-        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
-            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
-
         self._validate_attention_sink()
         self._validate_local_global_attention()
 
@@ -260,25 +261,31 @@ class DebugConfig:
 ################################################################################
 
 
-PT2E_QUANTIZE_OPTIONS = [
-    "xnnpack_dynamic",
-    "xnnpack_dynamic_qc4",
-    "qnn_8a8w",
-    "qnn_16a16w",
-    "qnn_16a4w",
-    "coreml_c4w",
-    "coreml_8a_c8w",
-    "coreml_8a_c4w",
-    "coreml_baseline_8a_c8w",
-    "coreml_baseline_8a_c4w",
-    "vulkan_8w",
-]
+class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """
 
+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"
 
-SPIN_QUANT_OPTIONS = [
-    "cuda",
-    "native",
-]
+
+class SpinQuant(str, Enum):
+    cuda = "cuda"
+    native = "native"
 
 
 @dataclass
@@ -313,22 +320,16 @@ class QuantizationConfig:
 
     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[str] = None
+    use_spin_quant: Optional[SpinQuant] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
 
     def __post_init__(self):
-        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
-            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
-
-        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
-            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
-
         if self.qmode:
             self._validate_qmode()
 
@@ -376,18 +377,16 @@ class XNNPackConfig:
     extended_ops: bool = False
 
 
-COREML_QUANTIZE_OPTIONS = [
-    "b4w",
-    "c4w",
-]
+class CoreMLQuantize(str, Enum):
+    b4w = "b4w"
+    c4w = "c4w"
 
 
-COREML_COMPUTE_UNIT_OPTIONS = [
-    "cpu_only",
-    "cpu_and_gpu",
-    "cpu_and_ne",
-    "all",
-]
+class CoreMLComputeUnit(str, Enum):
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"
 
 
 @dataclass
@@ -399,17 +398,11 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[str] = None
+    quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: str = "cpu_only"
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only
 
     def __post_init__(self):
-        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
-            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
-
-        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
-            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
-
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")
 
@@ -488,7 +481,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = args.model
+            llm_config.base.model_class = ModelType(args.model)
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -506,15 +499,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = args.preq_mode
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
 
         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = args.dtype_override
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -556,11 +549,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = args.use_spin_quant
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -588,11 +581,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             args, "coreml_preserve_sdpa", False
         )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = args.coreml_quantize
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = args.coreml_compute_units
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )
 
         # Vulkan
         if hasattr(args, "vulkan"):
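
For anyone trying the change out, here is a minimal, self-contained sketch (not part of the diff) of how the new `str`-backed enums behave. The trimmed `ModelType`/`DtypeOverride` re-declarations below exist only to make the snippet runnable on its own; in the repo they come from the config module this commit edits.

```python
import argparse
from enum import Enum


# Trimmed re-declarations of the enums from this diff, just so the sketch
# runs standalone; use the real ones from the config module in practice.
class ModelType(str, Enum):
    llama3 = "llama3"
    qwen3_0_6b = "qwen3-0_6b"


class DtypeOverride(str, Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"


# Stand-in for the parsed CLI args that LlmConfig.from_args receives.
ns = argparse.Namespace(model="qwen3-0_6b", dtype_override="fp32")

# Mirrors what from_args now does: the Enum constructor looks the raw string
# up by value, so the hand-written membership checks deleted from the
# __post_init__ methods above are no longer needed.
model = ModelType(ns.model)               # ModelType.qwen3_0_6b
dtype = DtypeOverride(ns.dtype_override)  # DtypeOverride.fp32

# Unknown values now fail at construction time instead.
try:
    ModelType("llama9")
except ValueError as err:
    print(err)  # 'llama9' is not a valid ModelType

# Because the enums also subclass str, comparisons against the raw string
# values still succeed, so downstream string checks keep working.
assert model == "qwen3-0_6b"
assert dtype.value == "fp32"
```

The same value-lookup pattern covers `PreqMode`, `Pt2eQuantize`, `SpinQuant`, `CoreMLQuantize`, and `CoreMLComputeUnit` in `from_args`.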