Uses dataclasses, which integrate with OmegaConf and Hydra.
1414"""
1515
16+ import re
1617from dataclasses import dataclass , field
17- from typing import List , Optional
18+ from enum import Enum
19+ from typing import List , Literal , Optional
20+
21+
22+ ################################################################################
23+ ################################## BaseConfig ##################################
24+ ################################################################################
25+
26+
class ModelType(str, Enum):
    """Model families/variants recognized by this export flow.

    The string values are the identifiers accepted in configs; this enum is
    used as the type of ``BaseConfig.model_class``. Subclassing ``str`` keeps
    the members directly comparable to plain-string config values.
    """

    STORIES110M = "stories110m"
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"
    LLAMA3_1 = "llama3_1"
    LLAMA3_2 = "llama3_2"
    LLAMA3_2_VISION = "llama3_2_vision"
    STATIC_LLAMA = "static_llama"
    QWEN2_5 = "qwen2_5"
    QWEN3_0_6B = "qwen3-0_6b"
    QWEN3_1_7B = "qwen3-1_7b"
    QWEN3_4B = "qwen3-4b"
    PHI_4_MINI = "phi_4_mini"
    SMOLLM2 = "smollm2"
41+
42+
class PreqMode(str, Enum):
    """Legacy pre-quantization modes applied during model weight loading.

    Used as the type of ``BaseConfig.preq_mode``.
    """

    PREQ_8DA4W = "8da4w"
    PREQ_8DA4W_OUT_8DA8W = "8da4w_output_8da8w"
1846
1947
@dataclass
class BaseConfig:
    """
    These are specific to the specific model, e.g. whether it's Qwen3 0.6B or
    Phi-4-mini. For each of these different models, you can expect each of
    these fields to change.
    """

    # Which model family/variant to build; see ModelType for accepted values.
    model_class: ModelType = ModelType.LLAMA3
    # Presumably a path to a model-parameters file — confirm with the loader.
    params: Optional[str] = None
    # Presumably a path to the model checkpoint — confirm with the loader.
    checkpoint: Optional[str] = None
    checkpoint_dir: Optional[str] = None  # For sharded checkpoint.
    tokenizer_path: Optional[str] = None
    metadata: Optional[str] = None
    use_lora: bool = False
    fairseq2: bool = False  # For legacy internal use cases.

    # Legacy pre-quantization options that happen during model weight loading.
    preq_mode: Optional[PreqMode] = None
    preq_group_size: int = 32
    preq_embedding_quantize: str = "8,0"
68+
69+
70+ ################################################################################
71+ ################################# ModelConfig ##################################
72+ ################################################################################
73+
74+
class DtypeOverride(str, Enum):
    """Floating-point dtype choices; used as ``ModelConfig.dtype_override``."""

    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
3479
3580
3681@dataclass
@@ -42,29 +87,39 @@ class ModelConfig:
4287 to different models.
4388 """
4489
45- dtype_override : str = "fp32"
90+ dtype_override : DtypeOverride = DtypeOverride . FP32
4691 enable_dynamic_shape : bool = True
4792 use_shared_embedding : bool = False
48- use_lora : bool = False
4993 use_sdpa_with_kv_cache : bool = False
5094 expand_rope_table : bool = False
95+ use_attention_sink : Optional [str ] = None
5196 output_prune_map : Optional [str ] = None
5297 input_prune_map : Optional [str ] = None
5398
5499 # Below are config options relating to kv cache.
55- use_kv_cache : Optional [bool ] = None
56- quantize_kv_cache : Optional [bool ] = None
57- local_global_attention : List [int ] = None
100+ use_kv_cache : bool = False
101+ quantize_kv_cache : bool = False
102+ local_global_attention : Optional [List [int ]] = None
103+
104+
105+ ################################################################################
106+ ################################ ExportConfig ##################################
107+ ################################################################################
58108
59109
@dataclass
class ExportConfig:
    """Options for the export step: sequence bounds and output artifacts."""

    # Maximum sequence length the exported model supports.
    max_seq_length: int = 128
    # Maximum context length; defaults to the same bound as max_seq_length.
    max_context_length: int = 128
    output_dir: Optional[str] = None
    output_name: Optional[str] = None
    # Path to a shared library — NOTE(review): consumer not visible here.
    so_library: Optional[str] = None
    # If True, presumably stop after export and skip later stages — confirm
    # against the pipeline that reads this config.
    export_only: bool = False
118+
119+
120+ ################################################################################
121+ ################################# DebugConfig ##################################
122+ ################################################################################
68123
69124
70125@dataclass
@@ -73,45 +128,101 @@ class DebugConfig:
73128 profile_path : Optional [str ] = None
74129 generate_etrecord : bool = False
75130 generate_full_logits : bool = False
76- verbose : bool = False # Would be good to remove this from the config eventually
131+ verbose : bool = False
132+
133+
134+ ################################################################################
135+ ############################# QuantizationConfig ###############################
136+ ################################################################################
77137
78138
79- ########################################################################
80- #### The below config can eventually be replaced by export recipes #####
81- ########################################################################
class Pt2eQuantize(str, Enum):
    """PT2E quantization recipes, grouped by backend (XNNPACK / QNN / CoreML /
    Vulkan); used as ``QuantizationConfig.pt2e_quantize``."""

    XNNPACK_DYNAMIC = "xnnpack_dynamic"
    XNNPACK_DYNAMIC_QC4 = "xnnpack_dynamic_qc4"
    QNN_8A8W = "qnn_8a8w"
    QNN_16A16W = "qnn_16a16w"
    QNN_16A4W = "qnn_16a4w"
    COREML_C4W = "coreml_c4w"
    COREML_8A_C8W = "coreml_8a_c8w"
    COREML_8A_C4W = "coreml_8a_c4w"
    COREML_BASELINE_8A_C8W = "coreml_baseline_8a_c8w"
    COREML_BASELINE_8A_C4W = "coreml_baseline_8a_c4w"
    VULKAN_8W = "vulkan_8w"
151+
152+
class SpinQuant(str, Enum):
    """SpinQuant execution options; used as ``QuantizationConfig.use_spin_quant``."""

    CUDA = "cuda"
    NATIVE = "native"
82156
83157
@dataclass
class QuantizationConfig:
    """Quantization options for the export flow.

    ``qmode`` accepts either one of a fixed set of named modes or a
    torchao-style pattern string (see ``_validate_qmode``). All fields default
    to None, meaning "not requested".
    """

    qmode: Optional[str] = None
    embedding_quantize: Optional[str] = None
    pt2e_quantize: Optional[Pt2eQuantize] = None
    group_size: Optional[int] = None
    use_spin_quant: Optional[SpinQuant] = None
    use_qat: Optional[bool] = None
    calibration_tasks: Optional[List[str]] = None
    calibration_limit: Optional[int] = None
    calibration_seq_length: Optional[int] = None
    calibration_data: Optional[str] = None

    def __post_init__(self):
        self._validate_qmode()

    def _validate_qmode(self) -> None:
        """Raise ValueError if ``qmode`` is set to an unrecognized value.

        BUGFIX: ``qmode`` defaults to None ("no quantization"), but the
        previous code unconditionally passed it to ``re.findall``, which
        raises TypeError on non-str input — making ``QuantizationConfig()``
        unconstructible with defaults. None is now accepted and skips
        validation.
        """
        if self.qmode is None:
            return

        choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
        patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"]

        if self.qmode in choices:
            return

        # Accept a torchao-style mode, e.g. "torchao:8da4w" — exactly one
        # match of one of the patterns above.
        for pattern in patterns:
            matches = re.findall(pattern, self.qmode)
            if len(matches) == 1:
                return

        raise ValueError(
            f"Got qmode {self.qmode}, but expected one of {choices}, or one of the regex patterns {patterns}."
        )
189+
190+
191+ ################################################################################
192+ ############################### BackendConfig ##################################
193+ ################################################################################
194+
100195
@dataclass
class XNNPackConfig:
    """Options for the XNNPACK backend delegate."""

    # Whether to delegate to XNNPACK at all.
    enabled: bool = False
    # NOTE(review): presumably enables additional op coverage in the
    # partitioner — the consumer is not visible here; confirm before relying.
    extended_ops: bool = False
200+
201+
class CoreMLQuantize(str, Enum):
    """Weight-quantization choices for the CoreML delegate
    (``CoreMLConfig.quantize``)."""

    B4W = "b4w"
    C4W = "c4w"
205+
206+
class CoreMLComputeUnit(str, Enum):
    """Compute units CoreML may run on (``CoreMLConfig.compute_units``)."""

    CPU_ONLY = "cpu_only"
    CPU_AND_GPU = "cpu_and_gpu"
    CPU_AND_NE = "cpu_and_ne"
    ALL = "all"
105212
106213
@dataclass
class CoreMLConfig:
    """Options for the CoreML backend delegate."""

    enabled: bool = False
    enable_state: bool = False
    preserve_sdpa: bool = False
    quantize: Optional[CoreMLQuantize] = None
    ios: Literal[15, 16, 17, 18] = 15
    compute_units: CoreMLComputeUnit = CoreMLComputeUnit.CPU_ONLY

    def __post_init__(self):
        # The Literal annotation is not enforced at runtime when configs are
        # built dynamically, so re-check the iOS version here.
        supported_versions = (15, 16, 17, 18)
        if self.ios in supported_versions:
            return
        raise ValueError(f"Invalid coreml ios version: {self.ios}")
115226
116227
117228@dataclass
@@ -143,6 +254,11 @@ class BackendConfig:
143254 mps : MPSConfig = field (default_factory = MPSConfig )
144255
145256
257+ ################################################################################
258+ ################################## LlmConfig ###################################
259+ ################################################################################
260+
261+
146262@dataclass
147263class LlmConfig :
148264 base : BaseConfig = field (default_factory = BaseConfig )
0 commit comments