Skip to content

Commit 7167028

Browse files
authored
Merge pull request #152 from BrandonGroth/qconfig_bug
fix: Saved qconfig recipe being overwritten with defaults
2 parents e8f35bb + a6bd15a commit 7167028

File tree

2 files changed

+61
-25
lines changed

2 files changed

+61
-25
lines changed

fms_mo/utils/qconfig_utils.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -713,44 +713,51 @@ def remove_unwanted_from_config(
713713
return config, dump
714714

715715

716-
def get_unwanted_defaults() -> dict:
716+
def get_unserializable_defaults() -> dict:
717717
"""Add back those unserializable items if needed"""
718-
unwanted_items = [
719-
("sweep_cv_percentile", False),
720-
("tb_writer", None),
721-
(
722-
"mapping",
723-
{
724-
nn.Conv2d: QConv2d,
725-
nn.ConvTranspose2d: QConvTranspose2d,
726-
nn.Linear: QLinear,
727-
nn.LSTM: QLSTM,
728-
"matmul_or_bmm": QBmm,
729-
},
730-
),
731-
("checkQerr_frequency", False),
732-
("newlySwappedModules", []),
733-
("force_calib_once", False),
718+
unserializable_items = {
719+
"sweep_cv_percentile": False,
720+
"tb_writer": None,
721+
"mapping": {
722+
nn.Conv2d: QConv2d,
723+
nn.ConvTranspose2d: QConvTranspose2d,
724+
nn.Linear: QLinear,
725+
nn.LSTM: QLSTM,
726+
"matmul_or_bmm": QBmm,
727+
},
728+
"checkQerr_frequency": False,
729+
"newlySwappedModules": [],
730+
"force_calib_once": False,
734731
# if we keep the follwing LUTs, it will save the entire model
735-
("LUTmodule_name", {}),
736-
]
737-
return unwanted_items
732+
"LUTmodule_name": {},
733+
}
734+
return unserializable_items
735+
736+
737+
def add_if_not_present(config: dict, items_to_add: dict) -> None:
738+
"""
739+
Add items to config dict only if they aren't present
740+
741+
Args:
742+
config (dict): Quantized config
743+
items_to_add (dict): Items that will be added if not present in config
744+
"""
745+
for key, val in items_to_add.items():
746+
if key not in config:
747+
config[key] = val
738748

739749

740750
def add_required_defaults_to_config(config: dict) -> None:
741751
"""Recover "unserializable" items that are previously removed from config"""
742-
unwanted_items = get_unwanted_defaults()
743-
for key, default_val in unwanted_items:
744-
if key not in config:
745-
config[key] = default_val
752+
add_if_not_present(config, get_unserializable_defaults())
746753

747754

748755
def add_wanted_defaults_to_config(config: dict, minimal: bool = True) -> None:
749756
"""Util function to add basic config defaults that are missing into a config
750757
if a wanted item is not in the config, add it w/ default value
751758
"""
752759
if not minimal:
753-
config.update(config_defaults())
760+
add_if_not_present(config, config_defaults())
754761

755762

756763
def qconfig_save(

tests/models/test_saveconfig.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121

2222
# Local
23+
from fms_mo import qconfig_init
2324
from fms_mo.utils.qconfig_utils import qconfig_load, qconfig_save
2425
from tests.models.test_model_utils import (
2526
delete_file,
@@ -298,3 +299,31 @@ def test_load_config_required_pair(
298299

299300
loaded_config = qconfig_load("qcfg.json")
300301
assert loaded_config.get(key) == default_val
302+
303+
304+
def test_save_init_recipe(
305+
config_int8: dict,
306+
):
307+
"""
308+
Change a config, save it,
309+
310+
Args:
311+
config_fp32 (dict): Config for fp32 quantization
312+
"""
313+
# Change some elements of config to ensure its being saved/loaded properly
314+
config_int8["qa_mode"] = "minmax"
315+
config_int8["qa_mode"] = "pertokenmax"
316+
config_int8["qmodel_calibration"] = 17
317+
config_int8["qskip_layer_name"] = ["lm_head"]
318+
319+
qconfig_save(config_int8)
320+
recipe_config = qconfig_init(recipe="qcfg.json")
321+
322+
# Remove date field from recipe_config - only added at save
323+
del recipe_config["date"]
324+
325+
assert len(recipe_config) == len(config_int8)
326+
327+
for key, val in config_int8.items():
328+
assert key in recipe_config
329+
assert recipe_config.get(key) == val

0 commit comments

Comments
 (0)