@@ -45,11 +45,16 @@ class ModelType(str, Enum):
     smollm2 = "smollm2"


+class PreqMode(str, Enum):
+    """
+    If you are dealing with pre-quantized checkpoints, this used to
+    be the way to specify them. Now you don't need to specify these
+    options if you use a TorchAo-prequantized checkpoint, but they
+    are still around to preserve backward compatibility.
+    """

-PREQ_MODE_OPTIONS = [
-    "8da4w",
-    "8da4w_output_8da8w",
-]
+    preq_8da4w = "8da4w"
+    preq_8da4w_out_8da8w = "8da4w_output_8da8w"


 @dataclass
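Note: because PreqMode mixes in str, the legacy string spellings keep working wherever a raw value is compared or passed; a quick sketch exercising the enum exactly as defined above:

    assert PreqMode("8da4w") is PreqMode.preq_8da4w  # value lookup accepts the old string
    assert PreqMode.preq_8da4w == "8da4w"            # str mixin keeps plain string comparisons true
    assert PreqMode.preq_8da4w.value == "8da4w"      # .value round-trips for serialization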
@@ -81,36 +86,34 @@ class BaseConfig:
     are loaded.
     """

-    model_class: str = "llama3"
+    model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[str] = None
+    preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"

-    def __post_init__(self):
-        if self.model_class not in MODEL_TYPE_OPTIONS:
-            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
-
-        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
-            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
-

 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################


-DTYPE_OVERRIDE_OPTIONS = [
-    "fp32",
-    "fp16",
-    "bf16",
-]
+class DtypeOverride(str, Enum):
+    """
+    DType of the model. Highly recommended to use "fp32", unless you want to
+    export without a backend, in which case you can also use "bf16". "fp16"
+    is not recommended.
+    """
+
+    fp32 = "fp32"
+    fp16 = "fp16"
+    bf16 = "bf16"


 @dataclass
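With the typed fields above, a BaseConfig is built from enum members rather than bare strings; a minimal sketch, assuming BaseConfig, ModelType, and PreqMode are importable from this module (the checkpoint path is a placeholder):

    base = BaseConfig(
        model_class=ModelType.llama3,          # the default; other ModelType members (e.g. smollm2) work the same way
        checkpoint="/path/to/checkpoint.pth",  # placeholder path
        preq_mode=PreqMode.preq_8da4w,         # optional; leave as None unless the checkpoint is pre-quantized
    )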
@@ -148,7 +151,7 @@ class ModelConfig:
     [16] pattern specifies all layers have a sliding window of 16.
     """

-    dtype_override: str = "fp32"
+    dtype_override: DtypeOverride = DtypeOverride.fp32
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -161,9 +164,6 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None

     def __post_init__(self):
-        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
-            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
-
         self._validate_attention_sink()
         self._validate_local_global_attention()

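Dropping this __post_init__ check does not lose validation: an unsupported string now fails at enum construction instead of at dataclass init. A small sketch, assuming DtypeOverride is importable from this module:

    DtypeOverride("fp32")  # -> DtypeOverride.fp32
    DtypeOverride("int8")  # raises ValueError: 'int8' is not a valid DtypeOverride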
@@ -265,25 +265,31 @@ class DebugConfig:
 ################################################################################


-PT2E_QUANTIZE_OPTIONS = [
-    "xnnpack_dynamic",
-    "xnnpack_dynamic_qc4",
-    "qnn_8a8w",
-    "qnn_16a16w",
-    "qnn_16a4w",
-    "coreml_c4w",
-    "coreml_8a_c8w",
-    "coreml_8a_c4w",
-    "coreml_baseline_8a_c8w",
-    "coreml_baseline_8a_c4w",
-    "vulkan_8w",
-]
+class Pt2eQuantize(str, Enum):
+    """
+    Type of backend-specific Pt2e quantization strategy to use.
+
+    Pt2e uses a different quantization library that is graph-based
+    compared to `qmode`, which is also specified in the QuantizationConfig
+    and is source transform-based.
+    """

+    xnnpack_dynamic = "xnnpack_dynamic"
+    xnnpack_dynamic_qc4 = "xnnpack_dynamic_qc4"
+    qnn_8a8w = "qnn_8a8w"
+    qnn_16a16w = "qnn_16a16w"
+    qnn_16a4w = "qnn_16a4w"
+    coreml_c4w = "coreml_c4w"
+    coreml_8a_c8w = "coreml_8a_c8w"
+    coreml_8a_c4w = "coreml_8a_c4w"
+    coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w"
+    coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w"
+    vulkan_8w = "vulkan_8w"

-SPIN_QUANT_OPTIONS = [
-    "cuda",
-    "native",
-]
+
+class SpinQuant(str, Enum):
+    cuda = "cuda"
+    native = "native"


 @dataclass
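Because each member's value matches the old option string one-for-one, anything that still needs the raw strings (for example an argparse choices list) can derive them from the enum. A hedged sketch; the --pt2e_quantize flag name is an assumption here:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pt2e_quantize",
        choices=[m.value for m in Pt2eQuantize],  # same strings the old PT2E_QUANTIZE_OPTIONS list held
        default=None,
    )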
@@ -318,22 +324,16 @@ class QuantizationConfig:

     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[str] = None
+    pt2e_quantize: Optional[Pt2eQuantize] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[str] = None
+    use_spin_quant: Optional[SpinQuant] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"

     def __post_init__(self):
-        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
-            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
-
-        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
-            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
-
         if self.qmode:
             self._validate_qmode()

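A sketch of setting the graph-based pt2e path on the quantization section, assuming QuantizationConfig, Pt2eQuantize, and SpinQuant are importable from this module; the particular field combination is illustrative, not a recommended recipe:

    quant = QuantizationConfig(
        pt2e_quantize=Pt2eQuantize.xnnpack_dynamic,  # graph-based pt2e path; qmode is the source transform-based alternative
        use_spin_quant=SpinQuant.native,             # optional; shown only to illustrate the enum-typed field
    )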
@@ -381,18 +381,16 @@ class XNNPackConfig:
     extended_ops: bool = False


-COREML_QUANTIZE_OPTIONS = [
-    "b4w",
-    "c4w",
-]
+class CoreMLQuantize(str, Enum):
+    b4w = "b4w"
+    c4w = "c4w"


-COREML_COMPUTE_UNIT_OPTIONS = [
-    "cpu_only",
-    "cpu_and_gpu",
-    "cpu_and_ne",
-    "all",
-]
+class CoreMLComputeUnit(str, Enum):
+    cpu_only = "cpu_only"
+    cpu_and_gpu = "cpu_and_gpu"
+    cpu_and_ne = "cpu_and_ne"
+    all = "all"


 @dataclass
@@ -404,17 +402,11 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[str] = None
+    quantize: Optional[CoreMLQuantize] = None
     ios: int = 15
-    compute_units: str = "cpu_only"
+    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.cpu_only

     def __post_init__(self):
-        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
-            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
-
-        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
-            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
-
         if self.ios not in (15, 16, 17, 18):
             raise ValueError(f"Invalid coreml ios version: {self.ios}")

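A sketch of the CoreML backend section with the enum-typed fields, assuming CoreMLConfig, CoreMLQuantize, and CoreMLComputeUnit are importable from this module; note the ios range check is the only validation left in __post_init__:

    coreml = CoreMLConfig(
        enabled=True,
        quantize=CoreMLQuantize.c4w,
        ios=17,                                      # must be one of 15, 16, 17, 18
        compute_units=CoreMLComputeUnit.cpu_and_ne,
    )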
@@ -493,7 +485,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = args.model
+            llm_config.base.model_class = ModelType(args.model)
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -511,15 +503,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = args.preq_mode
+            llm_config.base.preq_mode = PreqMode(args.preq_mode)
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize

         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = args.dtype_override
+            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -561,11 +553,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
+            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = args.use_spin_quant
+            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -593,11 +585,13 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             args, "coreml_preserve_sdpa", False
         )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = args.coreml_quantize
+            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = args.coreml_compute_units
+            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
+                args.coreml_compute_units
+            )

         # Vulkan
         if hasattr(args, "vulkan"):
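End to end, the coercions above mean from_args hands back enum-typed values even though argparse only ever produces strings. A rough usage sketch; the flag names mirror the attribute names checked via hasattr and are assumptions here:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="llama3")
    parser.add_argument("--dtype_override", default="fp32")
    args = parser.parse_args([])

    llm_config = LlmConfig.from_args(args)
    assert llm_config.base.model_class is ModelType.llama3
    assert llm_config.model.dtype_override is DtypeOverride.fp32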