 import ast
 import re
 from dataclasses import dataclass, field
-from enum import Enum
 from typing import ClassVar, List, Optional


 ################################################################################


-class ModelType(str, Enum):
-    STORIES110M = "stories110m"
-    LLAMA2 = "llama2"
-    LLAMA3 = "llama3"
-    LLAMA3_1 = "llama3_1"
-    LLAMA3_2 = "llama3_2"
-    LLAMA3_2_VISION = "llama3_2_vision"
-    STATIC_LLAMA = "static_llama"
-    QWEN2_5 = "qwen2_5"
-    QWEN3_0_6B = "qwen3-0_6b"
-    QWEN3_1_7B = "qwen3-1_7b"
-    QWEN3_4B = "qwen3-4b"
-    PHI_4_MINI = "phi_4_mini"
-    SMOLLM2 = "smollm2"
+MODEL_TYPE_OPTIONS = [
+    "stories110m",
+    "llama2",
+    "llama3",
+    "llama3_1",
+    "llama3_2",
+    "llama3_2_vision",
+    "static_llama",
+    "qwen2_5",
+    "qwen3-0_6b",
+    "qwen3-1_7b",
+    "qwen3-4b",
+    "phi_4_mini",
+    "smollm2",
+]


-class PreqMode(str, Enum):
-    """
-    If you are dealing with pre-quantized checkpoints, this used to
-    be the way to specify them. Now you don't need to specify these
-    options if you use a TorchAo-prequantized checkpoint, but they
-    are still around to preserve backward compatibility.
-    """
-
-    PREQ_8DA4W = "8da4w"
-    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
+PREQ_MODE_OPTIONS = [
+    "8da4w",
+    "8da4w_output_8da8w",
+]


 @dataclass
@@ -82,34 +76,36 @@ class BaseConfig:
         are loaded.
     """

-    model_class: ModelType = ModelType.LLAMA3
+    model_class: str = "llama3"
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
     fairseq2: bool = False
-    preq_mode: Optional[PreqMode] = None
+    preq_mode: Optional[str] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"

+    def __post_init__(self):
+        if self.model_class not in MODEL_TYPE_OPTIONS:
+            raise ValueError(f"model_class must be one of {MODEL_TYPE_OPTIONS}, got '{self.model_class}'")
+
+        if self.preq_mode is not None and self.preq_mode not in PREQ_MODE_OPTIONS:
+            raise ValueError(f"preq_mode must be one of {PREQ_MODE_OPTIONS}, got '{self.preq_mode}'")
+

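For reference, a minimal usage sketch (not part of the diff) of the plain-string fields validated above; valid values come from the option lists defined earlier, and "not-a-model" is just a deliberately invalid placeholder:

```python
# Assuming BaseConfig from this module; values are plain strings now.
cfg = BaseConfig(model_class="qwen3-4b", preq_mode="8da4w")

# An unknown value fails fast in __post_init__ with a descriptive error.
try:
    BaseConfig(model_class="not-a-model")
except ValueError as err:
    print(err)  # model_class must be one of [...], got 'not-a-model'
```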
 ################################################################################
 ################################# ModelConfig ##################################
 ################################################################################


-class DtypeOverride(str, Enum):
-    """
-    DType of the model. Highly recommended to use "fp32", unless you want to
-    export without a backend, in which case you can also use "bf16". "fp16"
-    is not recommended.
-    """
-
-    FP32 = "fp32"
-    FP16 = "fp16"
-    BF16 = "bf16"
+DTYPE_OVERRIDE_OPTIONS = [
+    "fp32",
+    "fp16",
+    "bf16",
+]


 @dataclass
@@ -147,7 +143,7 @@ class ModelConfig:
         [16] pattern specifies all layers have a sliding window of 16.
     """

-    dtype_override: DtypeOverride = DtypeOverride.FP32
+    dtype_override: str = "fp32"
     enable_dynamic_shape: bool = True
     use_shared_embedding: bool = False
     use_sdpa_with_kv_cache: bool = False
@@ -160,6 +156,9 @@ class ModelConfig:
     local_global_attention: Optional[List[int]] = None

     def __post_init__(self):
+        if self.dtype_override not in DTYPE_OVERRIDE_OPTIONS:
+            raise ValueError(f"dtype_override must be one of {DTYPE_OVERRIDE_OPTIONS}, got '{self.dtype_override}'")
+
         self._validate_attention_sink()
         self._validate_local_global_attention()

@@ -261,31 +260,25 @@ class DebugConfig:
 ################################################################################


-class Pt2eQuantize(str, Enum):
-    """
-    Type of backend-specific Pt2e quantization strategy to use.
-
-    Pt2e uses a different quantization library that is graph-based
-    compared to `qmode`, which is also specified in the QuantizationConfig
-    and is source transform-based.
-    """
+PT2E_QUANTIZE_OPTIONS = [
+    "xnnpack_dynamic",
+    "xnnpack_dynamic_qc4",
+    "qnn_8a8w",
+    "qnn_16a16w",
+    "qnn_16a4w",
+    "coreml_c4w",
+    "coreml_8a_c8w",
+    "coreml_8a_c4w",
+    "coreml_baseline_8a_c8w",
+    "coreml_baseline_8a_c4w",
+    "vulkan_8w",
+]

-    XNNPACK_DYNAMIC = "xnnpack_dynamic"
-    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
-    QNN_8A8W = "qnn_8a8w"
-    QNN_16A16W = "qnn_16a16w"
-    QNN_16A4W = "qnn_16a4w"
-    COREML_C4W = "coreml_c4w"
-    COREML_8A_C8W = "coreml_8a_c8w"
-    COREML_8A_C4W = "coreml_8a_c4w"
-    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
-    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
-    VULKAN_8W = "vulkan_8w"

-
-class SpinQuant(str, Enum):
-    CUDA = "cuda"
-    NATIVE = "native"
+SPIN_QUANT_OPTIONS = [
+    "cuda",
+    "native",
+]


 @dataclass
@@ -320,16 +313,22 @@ class QuantizationConfig:

     qmode: Optional[str] = None
     embedding_quantize: Optional[str] = None
-    pt2e_quantize: Optional[Pt2eQuantize] = None
+    pt2e_quantize: Optional[str] = None
     group_size: Optional[int] = None
-    use_spin_quant: Optional[SpinQuant] = None
+    use_spin_quant: Optional[str] = None
     use_qat: bool = False
     calibration_tasks: Optional[List[str]] = None
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"

     def __post_init__(self):
+        if self.pt2e_quantize is not None and self.pt2e_quantize not in PT2E_QUANTIZE_OPTIONS:
+            raise ValueError(f"pt2e_quantize must be one of {PT2E_QUANTIZE_OPTIONS}, got '{self.pt2e_quantize}'")
+
+        if self.use_spin_quant is not None and self.use_spin_quant not in SPIN_QUANT_OPTIONS:
+            raise ValueError(f"use_spin_quant must be one of {SPIN_QUANT_OPTIONS}, got '{self.use_spin_quant}'")
+
         if self.qmode:
             self._validate_qmode()

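A similar sketch (again not part of the diff) for QuantizationConfig as defined above: optional fields left as None skip the check entirely, and only explicitly set values are compared against the option lists. "rocm" is a deliberately invalid placeholder:

```python
# None is still allowed for the optional fields, so defaults pass untouched.
QuantizationConfig()

# A listed value passes; an unlisted one raises in __post_init__.
QuantizationConfig(pt2e_quantize="qnn_16a4w")
try:
    QuantizationConfig(use_spin_quant="rocm")
except ValueError as err:
    print(err)  # use_spin_quant must be one of ['cuda', 'native'], got 'rocm'
```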
@@ -377,16 +376,18 @@ class XNNPackConfig:
     extended_ops: bool = False


-class CoreMLQuantize(str, Enum):
-    B4W = "b4w"
-    C4W = "c4w"
+COREML_QUANTIZE_OPTIONS = [
+    "b4w",
+    "c4w",
+]


-class CoreMLComputeUnit(str, Enum):
-    CPU_ONLY = "cpu_only"
-    CPU_AND_GPU = "cpu_and_gpu"
-    CPU_AND_NE = "cpu_and_ne"
-    ALL = "all"
+COREML_COMPUTE_UNIT_OPTIONS = [
+    "cpu_only",
+    "cpu_and_gpu",
+    "cpu_and_ne",
+    "all",
+]


 @dataclass
@@ -398,11 +399,17 @@ class CoreMLConfig:
     enabled: bool = False
     enable_state: bool = False
     preserve_sdpa: bool = False
-    quantize: Optional[CoreMLQuantize] = None
+    quantize: Optional[str] = None
     ios: int = 15
-    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY
+    compute_units: str = "cpu_only"

     def __post_init__(self):
+        if self.quantize is not None and self.quantize not in COREML_QUANTIZE_OPTIONS:
+            raise ValueError(f"quantize must be one of {COREML_QUANTIZE_OPTIONS}, got '{self.quantize}'")
+
+        if self.compute_units not in COREML_COMPUTE_UNIT_OPTIONS:
+            raise ValueError(f"compute_units must be one of {COREML_COMPUTE_UNIT_OPTIONS}, got '{self.compute_units}'")
+
         if self.ios not in (15, 16, 17, 18):
            raise ValueError(f"Invalid coreml ios version: {self.ios}")

@@ -481,7 +488,7 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

         # BaseConfig
         if hasattr(args, "model"):
-            llm_config.base.model_class = ModelType(args.model)
+            llm_config.base.model_class = args.model
         if hasattr(args, "params"):
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
@@ -499,15 +506,15 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901

         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
-            llm_config.base.preq_mode = PreqMode(args.preq_mode)
+            llm_config.base.preq_mode = args.preq_mode
         if hasattr(args, "preq_group_size"):
             llm_config.base.preq_group_size = args.preq_group_size
         if hasattr(args, "preq_embedding_quantize"):
             llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize

         # ModelConfig
         if hasattr(args, "dtype_override"):
-            llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
+            llm_config.model.dtype_override = args.dtype_override
         if hasattr(args, "enable_dynamic_shape"):
             llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
         if hasattr(args, "use_shared_embedding"):
@@ -549,11 +556,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
         if hasattr(args, "embedding_quantize"):
             llm_config.quantization.embedding_quantize = args.embedding_quantize
         if hasattr(args, "pt2e_quantize") and args.pt2e_quantize:
-            llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
+            llm_config.quantization.pt2e_quantize = args.pt2e_quantize
         if hasattr(args, "group_size"):
             llm_config.quantization.group_size = args.group_size
         if hasattr(args, "use_spin_quant") and args.use_spin_quant:
-            llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
+            llm_config.quantization.use_spin_quant = args.use_spin_quant
         if hasattr(args, "use_qat"):
             llm_config.quantization.use_qat = args.use_qat
         if hasattr(args, "calibration_tasks"):
@@ -581,13 +588,11 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             args, "coreml_preserve_sdpa", False
         )
         if hasattr(args, "coreml_quantize") and args.coreml_quantize:
-            llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
+            llm_config.backend.coreml.quantize = args.coreml_quantize
         if hasattr(args, "coreml_ios"):
             llm_config.backend.coreml.ios = args.coreml_ios
         if hasattr(args, "coreml_compute_units"):
-            llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
-                args.coreml_compute_units
-            )
+            llm_config.backend.coreml.compute_units = args.coreml_compute_units

         # Vulkan
         if hasattr(args, "vulkan"):