@@ -713,44 +713,51 @@ def remove_unwanted_from_config(
713713 return config , dump
714714
715715
716- def get_unwanted_defaults () -> dict :
716+ def get_unserializable_defaults () -> dict :
717717 """Add back those unserializable items if needed"""
718- unwanted_items = [
719- ("sweep_cv_percentile" , False ),
720- ("tb_writer" , None ),
721- (
722- "mapping" ,
723- {
724- nn .Conv2d : QConv2d ,
725- nn .ConvTranspose2d : QConvTranspose2d ,
726- nn .Linear : QLinear ,
727- nn .LSTM : QLSTM ,
728- "matmul_or_bmm" : QBmm ,
729- },
730- ),
731- ("checkQerr_frequency" , False ),
732- ("newlySwappedModules" , []),
733- ("force_calib_once" , False ),
718+ unserializable_items = {
719+ "sweep_cv_percentile" : False ,
720+ "tb_writer" : None ,
721+ "mapping" : {
722+ nn .Conv2d : QConv2d ,
723+ nn .ConvTranspose2d : QConvTranspose2d ,
724+ nn .Linear : QLinear ,
725+ nn .LSTM : QLSTM ,
726+ "matmul_or_bmm" : QBmm ,
727+ },
728+ "checkQerr_frequency" : False ,
729+ "newlySwappedModules" : [],
730+ "force_calib_once" : False ,
734731 # if we keep the follwing LUTs, it will save the entire model
735- ("LUTmodule_name" , {}),
736- ]
737- return unwanted_items
732+ "LUTmodule_name" : {},
733+ }
734+ return unserializable_items
735+
736+
737+ def add_if_not_present (config : dict , items_to_add : dict ) -> None :
738+ """
739+ Add items to config dict only if they aren't present
740+
741+ Args:
742+ config (dict): Quantized config
743+ items_to_add (dict): Items that will be added if not present in config
744+ """
745+ for key , val in items_to_add .items ():
746+ if key not in config :
747+ config [key ] = val
738748
739749
740750def add_required_defaults_to_config (config : dict ) -> None :
741751 """Recover "unserializable" items that are previously removed from config"""
742- unwanted_items = get_unwanted_defaults ()
743- for key , default_val in unwanted_items :
744- if key not in config :
745- config [key ] = default_val
752+ add_if_not_present (config , get_unserializable_defaults ())
746753
747754
748755def add_wanted_defaults_to_config (config : dict , minimal : bool = True ) -> None :
749756 """Util function to add basic config defaults that are missing into a config
750757 if a wanted item is not in the config, add it w/ default value
751758 """
752759 if not minimal :
753- config . update ( config_defaults ())
760+ add_if_not_present ( config , config_defaults ())
754761
755762
756763def qconfig_save (
0 commit comments