1818from datetime import date
1919from importlib .metadata import version
2020from pathlib import Path
21- from typing import Any
21+ from typing import Any , Union
2222import json
2323import logging
2424import os
@@ -113,10 +113,10 @@ def config_defaults() -> dict:
113113 "qkvsync" : False ,
114114 "extend_act_range" : False ,
115115 "plotsvg" : False ,
116+ "qskip_large_mag_layers" : False ,
116117 # Iterable vars
117118 "qlayer_name_pattern" : [],
118119 "qskip_layer_name" : [],
119- "qskip_large_mag_layers" : False ,
120120 "qspecial_layers" : {},
121121 "qsinglesided_name" : [],
122122 "clip_val_asst_percentile" : (0.1 , 99.9 ),
@@ -142,21 +142,24 @@ def config_defaults() -> dict:
142142 "temp_disable_calib" : False ,
143143 "org_batch_size" : {},
144144 "ptqmod_to_be_optimized" : [],
145+ # SmoothQuant vars
146+ "smoothq" : False ,
147+ "smoothq_scale_layers" : [],
148+ "smoothq_act_scale_path" : None ,
145149 # Other vars
146150 "which2patch_contextmanager" : None ,
147151 "force_stop_if_qbmm_auto_check_failed" : False ,
148152 "world_size" : max (1 , torch .cuda .device_count ()),
149153 "global_rank" : 0 ,
150154 "batch_size" : 2 ,
155+ "keys_to_save" : [],
151156 # items could be obsoleted
152157 "output_attentions" : False ,
153158 "bias_corr" : False ,
154159 "qwav2vec" : False ,
155160 "qvit" : False ,
156161 "numparamsfromloadertomodel" : 1 ,
157162 "gradclip" : 0.0 ,
158- "smoothq" : False ,
159- "keys_to_save" : [],
160163 }
161164
162165 return cfg_defaults
@@ -201,7 +204,7 @@ def find_recipe_json(recipe: str, subdir: str = None) -> Path:
201204 return json_file
202205
203206
204- def get_recipe (recipe : str , subdir : str = None ) -> Any :
207+ def get_recipe (recipe : str , subdir : str = None ) -> Union [ list , dict ] :
205208 """
206209 Get a json recipe.
207210
@@ -219,6 +222,10 @@ def get_recipe(recipe: str, subdir: str = None) -> Any:
219222 temp_data = json .load (openfile )
220223 logger .info (f"Loaded settings from { json_file } ." )
221224
225+ # Any recipe should be a dict (qcfg) or list (keys_to_save)
226+ if not isinstance (temp_data , (dict , list )):
227+ raise ValueError (f"Loaded recipe { json_file } was not a dict or list" )
228+
222229 return temp_data
223230
224231
@@ -378,8 +385,14 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict:
378385 # this can be used to load a previously saved ckpt as well
379386 if recipe :
380387 # qcfg recipes should reside in fms_mo/recipes
381- temp_cfg = get_recipe (recipe )
388+ temp_cfg = qconfig_load (recipe )
389+
382390 if temp_cfg :
391+ if not isinstance (temp_cfg , dict ):
392+ raise ValueError (
393+ f"Quantized config recipe={ recipe } is not a dictionary"
394+ )
395+
383396 qcfg .update (temp_cfg )
384397 logger .info ("Updated config with recipe values" )
385398 else :
@@ -562,7 +575,12 @@ def qconfig_save(
562575
563576 # Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
564577 keys_to_save_json = get_recipe (recipe )
578+
565579 if keys_to_save_json :
580+ if not isinstance (keys_to_save_json , list ):
581+ raise ValueError (f"Save recipe={ recipe } is not a list!" )
582+
583+ # Merge keys_to_save lists
566584 keys_to_save = list (set (keys_to_save + keys_to_save_json ))
567585
568586 # If we found keys to save, fetch them from qcfg
@@ -604,9 +622,12 @@ def qconfig_save(
604622
605623def qconfig_load (fname : str = "qcfg.json" ) -> dict :
606624 """Read config in json format, work together with qconfig_save"""
607- if os .path .isfile (fname ):
608- with open (fname , "r" , encoding = "utf-8" ) as openfile :
609- config = json .load (openfile )
625+ config = get_recipe (fname )
626+
627+ if config :
628+ # Check that loaded file is a dict
629+ if not isinstance (config , dict ):
630+ raise ValueError (f"Quantized config={ fname } is not a dictionary" )
610631
611632 # Add back wanted defaults for any missing vars
612633 add_wanted_defaults_to_config (config , minimal = False )
@@ -856,6 +877,8 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
856877 "plotsvg" ,
857878 "ptq_freezecvs" ,
858879 "ptq_qdrop" ,
880+ "qskip_large_mag_layers" ,
881+ "smoothq" ,
859882 ]
860883 for boolean_var_str in boolean_vars_str :
861884 boolean_var = config .get (
@@ -912,6 +935,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
912935 "firstptqmodule" ,
913936 "params2optim" ,
914937 "clip_val_asst_percentile" ,
938+ "smoothq_scale_layers" ,
915939 ]
916940 for iterable_var_str in iterable_vars_str :
917941 iterable_var_default = default_config .get (iterable_var_str )
@@ -990,3 +1014,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
9901014 f"which2patch_contextmanager = { which2patch_contextmanager } is not one of "
9911015 f"the following: { which2patch_contextmanager_settings } "
9921016 )
1017+
1018+ smoothq_act_scale_path = config .get ("smoothq_act_scale_path" , None )
1019+ if smoothq_act_scale_path and not smoothq_act_scale_path .endswith (".pt" ):
1020+ raise ValueError (f"{ smoothq_act_scale_path = } is not a .pt checkpoint" )
0 commit comments