Skip to content

Commit 00bad82

Browse files
committed
fix: Fixed recipe being overwritten in qconfig_init
Signed-off-by: Brandon Groth <[email protected]>
1 parent 9b75c10 commit 00bad82

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

fms_mo/utils/qconfig_utils.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -713,44 +713,51 @@ def remove_unwanted_from_config(
713713
return config, dump
714714

715715

716-
def get_unwanted_defaults() -> dict:
716+
def get_unserializable_defaults() -> dict:
717717
"""Add back those unserializable items if needed"""
718-
unwanted_items = [
719-
("sweep_cv_percentile", False),
720-
("tb_writer", None),
721-
(
722-
"mapping",
723-
{
724-
nn.Conv2d: QConv2d,
725-
nn.ConvTranspose2d: QConvTranspose2d,
726-
nn.Linear: QLinear,
727-
nn.LSTM: QLSTM,
728-
"matmul_or_bmm": QBmm,
729-
},
730-
),
731-
("checkQerr_frequency", False),
732-
("newlySwappedModules", []),
733-
("force_calib_once", False),
718+
unserializable_items = {
719+
"sweep_cv_percentile": False,
720+
"tb_writer": None,
721+
"mapping": {
722+
nn.Conv2d: QConv2d,
723+
nn.ConvTranspose2d: QConvTranspose2d,
724+
nn.Linear: QLinear,
725+
nn.LSTM: QLSTM,
726+
"matmul_or_bmm": QBmm,
727+
},
728+
"checkQerr_frequency": False,
729+
"newlySwappedModules": [],
730+
"force_calib_once": False,
734731
# if we keep the follwing LUTs, it will save the entire model
735-
("LUTmodule_name", {}),
736-
]
737-
return unwanted_items
732+
"LUTmodule_name": {},
733+
}
734+
return unserializable_items
735+
736+
737+
def add_if_not_present(config: dict, items_to_add: dict) -> None:
738+
"""
739+
Add items to config dict only if they aren't present
740+
741+
Args:
742+
config (dict): Quantized config
743+
items_to_add (dict): Items that will be added if not present in config
744+
"""
745+
for key, val in items_to_add.items():
746+
if key not in config:
747+
config[key] = val
738748

739749

740750
def add_required_defaults_to_config(config: dict) -> None:
741751
"""Recover "unserializable" items that are previously removed from config"""
742-
unwanted_items = get_unwanted_defaults()
743-
for key, default_val in unwanted_items:
744-
if key not in config:
745-
config[key] = default_val
752+
add_if_not_present(config, get_unserializable_defaults())
746753

747754

748755
def add_wanted_defaults_to_config(config: dict, minimal: bool = True) -> None:
749756
"""Util function to add basic config defaults that are missing into a config
750757
if a wanted item is not in the config, add it w/ default value
751758
"""
752759
if not minimal:
753-
config.update(config_defaults())
760+
add_if_not_present(config, config_defaults())
754761

755762

756763
def qconfig_save(

0 commit comments

Comments
 (0)