1818from datetime import date
1919from importlib .metadata import version
2020from pathlib import Path
21- from typing import Any
21+ from typing import Any , Union
2222import json
2323import logging
2424import os
@@ -113,6 +113,7 @@ def config_defaults() -> dict:
113113 "qkvsync" : False ,
114114 "extend_act_range" : False ,
115115 "plotsvg" : False ,
116+ "qskip_large_mag_layers" : False ,
116117 # Iterable vars
117118 "qlayer_name_pattern" : [],
118119 "qskip_layer_name" : [],
@@ -142,21 +143,24 @@ def config_defaults() -> dict:
142143 "temp_disable_calib" : False ,
143144 "org_batch_size" : {},
144145 "ptqmod_to_be_optimized" : [],
146+ # SmoothQuant vars
147+ "smoothq" : False ,
148+ "smoothq_scale_layers" : [],
149+ "smoothq_act_scale_path" : None ,
145150 # Other vars
146151 "which2patch_contextmanager" : None ,
147152 "force_stop_if_qbmm_auto_check_failed" : False ,
148153 "world_size" : max (1 , torch .cuda .device_count ()),
149154 "global_rank" : 0 ,
150155 "batch_size" : 2 ,
156+ "keys_to_save" : [],
151157 # items could be obsoleted
152158 "output_attentions" : False ,
153159 "bias_corr" : False ,
154160 "qwav2vec" : False ,
155161 "qvit" : False ,
156162 "numparamsfromloadertomodel" : 1 ,
157163 "gradclip" : 0.0 ,
158- "smoothq" : False ,
159- "keys_to_save" : [],
160164 }
161165
162166 return cfg_defaults
@@ -201,7 +205,7 @@ def find_recipe_json(recipe: str, subdir: str = None) -> Path:
201205 return json_file
202206
203207
204- def get_recipe (recipe : str , subdir : str = None ) -> Any :
208+ def get_recipe (recipe : str , subdir : str = None ) -> Union [ list , dict ] :
205209 """
206210 Get a json recipe.
207211
@@ -219,6 +223,10 @@ def get_recipe(recipe: str, subdir: str = None) -> Any:
219223 temp_data = json .load (openfile )
220224 logger .info (f"Loaded settings from { json_file } ." )
221225
226+ # Any recipe should be a dict (qcfg) or list (keys_to_save)
227+ if not isinstance (temp_data , (dict , list )):
228+ raise ValueError (f"Loaded recipe { json_file } was not a dict or list" )
229+
222230 return temp_data
223231
224232
@@ -378,8 +386,14 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict:
378386 # this can be used to load a previously saved ckpt as well
379387 if recipe :
380388 # qcfg recipes should reside in fms_mo/recipes
381- temp_cfg = get_recipe (recipe )
389+ temp_cfg = qconfig_load (recipe )
390+
382391 if temp_cfg :
392+ if not isinstance (temp_cfg , dict ):
393+ raise ValueError (
394+ f"Quantized config recipe={ recipe } is not a dictionary"
395+ )
396+
383397 qcfg .update (temp_cfg )
384398 logger .info ("Updated config with recipe values" )
385399 else :
@@ -562,7 +576,12 @@ def qconfig_save(
562576
563577 # Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
564578 keys_to_save_json = get_recipe (recipe )
579+
565580 if keys_to_save_json :
581+ if not isinstance (keys_to_save_json , list ):
582+ raise ValueError (f"Save recipe={ recipe } is not a list!" )
583+
584+ # Merge keys_to_save lists
566585 keys_to_save = list (set (keys_to_save + keys_to_save_json ))
567586
568587 # If we found keys to save, fetch them from qcfg
@@ -604,9 +623,12 @@ def qconfig_save(
604623
605624def qconfig_load (fname : str = "qcfg.json" ) -> dict :
606625 """Read config in json format, work together with qconfig_save"""
607- if os .path .isfile (fname ):
608- with open (fname , "r" , encoding = "utf-8" ) as openfile :
609- config = json .load (openfile )
626+ config = get_recipe (fname )
627+
628+ if config :
629+ # Check that loaded file is a dict
630+ if not isinstance (config , dict ):
631+ raise ValueError (f"Quantized config={ fname } is not a dictionary" )
610632
611633 # Add back wanted defaults for any missing vars
612634 add_wanted_defaults_to_config (config , minimal = False )
@@ -856,6 +878,8 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
856878 "plotsvg" ,
857879 "ptq_freezecvs" ,
858880 "ptq_qdrop" ,
881+ "qskip_large_mag_layers" ,
882+ "smoothq" ,
859883 ]
860884 for boolean_var_str in boolean_vars_str :
861885 boolean_var = config .get (
@@ -912,6 +936,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
912936 "firstptqmodule" ,
913937 "params2optim" ,
914938 "clip_val_asst_percentile" ,
939+ "smoothq_scale_layers" ,
915940 ]
916941 for iterable_var_str in iterable_vars_str :
917942 iterable_var_default = default_config .get (iterable_var_str )
@@ -990,3 +1015,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
9901015 f"which2patch_contextmanager = { which2patch_contextmanager } is not one of "
9911016 f"the following: { which2patch_contextmanager_settings } "
9921017 )
1018+
1019+ smoothq_act_scale_path = config .get ("smoothq_act_scale_path" , None )
1020+ if smoothq_act_scale_path and not smoothq_act_scale_path .endswith (".pt" ):
1021+ raise ValueError (f"{ smoothq_act_scale_path = } is not a .pt checkpoint" )
0 commit comments