
Commit 34e5ad1

fix: Added guards to qconfig save/load, added smoothq prefix for vars, and added smoothq vars to default config
Signed-off-by: Brandon Groth <[email protected]>
1 parent 1a0161a commit 34e5ad1

File tree

4 files changed: +52 -23 lines


fms_mo/quant/ptq.py

Lines changed: 1 addition & 1 deletion
@@ -2537,7 +2537,7 @@ def dq_llm(model, scale, qcfg):

     for name, module in model.named_modules():
         if isinstance(module, (QLinear,)):
-            if any(x in name for x in qcfg["scale_layers"]):
+            if any(x in name for x in qcfg["smoothq_scale_layers"]):
                 module.set_act_scale(scale[name])
                 logger.info(
                     f"Apply layer {name} with activation scales (10)"

fms_mo/recipes/dq.json

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
     "decoder_arch": true,
     "align_zero": true,
     "qgroup": null,
-    "act_scale_path": null,
+    "smoothq_act_scale_path": null,
     "qmodel_calibration_new": 10,
     "qskip_large_mag_layers": true,
     "ptq_nbatch": 128,

fms_mo/utils/dq_utils.py

Lines changed: 13 additions & 13 deletions
@@ -18,9 +18,9 @@ def config_quantize_smooth_layers(qcfg: dict):
     """Update qcfg with model-dependent config parameters:
     - qlayer_name_pattern: identifier of transformer layers containing linear layers
       to quantize (if any, tracing is bypassed)
-    - scale_layers: identifier of linear layers to apply smoothquant on
     - qskip_layer_name: full name of linear layers that will not be quantized
-    - act_scale_path: path to save/load smoothquant activation scales
+    - smoothq_scale_layers: identifier of linear layers to apply smoothquant on
+    - smoothq_act_scale_path: path to save/load smoothquant activation scales

     Selected model is determined by comparing all architecture identifiers against
     `model` and `model_type` fields in qcfg.
@@ -56,7 +56,7 @@ def config_quantize_smooth_layers(qcfg: dict):
         model in qcfg["model_type"] for model in llama_architecture
     ):
         qcfg["qlayer_name_pattern"] = ["model.layers."]
-        qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
+        qcfg["smoothq_scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
         if qcfg["qskip_large_mag_layers"]:
             large_mag_layers = {
                 "2-7b": [1, 30],
@@ -75,13 +75,13 @@ def config_quantize_smooth_layers(qcfg: dict):
         model in qcfg["model_type"] for model in granite_architecture
     ):
         qcfg["qlayer_name_pattern"] = ["model.layers."]
-        qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
+        qcfg["smoothq_scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
         # NOTE: supported granite-v3 models do not need layer skip for large magnitude
     elif "mixtral" in qcfg["model"]:
         qcfg["qlayer_name_pattern"] = (
             ["model.layers"] if qcfg["nbits_bmm1"] == 32 else []
         )
-        qcfg["scale_layers"] = ["q_proj", "k_proj", "v_proj", "w1", "w3"]
+        qcfg["smoothq_scale_layers"] = ["q_proj", "k_proj", "v_proj", "w1", "w3"]
         qcfg["qskip_layer_name"] += [
             f"model.layers.{i}.block_sparse_moe.gate" for i in range(32)
         ]
@@ -98,22 +98,22 @@ def config_quantize_smooth_layers(qcfg: dict):
                 [31, 7],
             ]
         ]
-        qcfg["act_scale_path"] = "./act_scales/Mixtral-8x7B-v0.1.pt"
+        qcfg["smoothq_act_scale_path"] = "./act_scales/Mixtral-8x7B-v0.1.pt"
     elif any(model in qcfg["model"] for model in bigcode_architecture):
         qcfg["qlayer_name_pattern"] = ["transformer.h"]
-        qcfg["scale_layers"] = ["c_attn", "c_fc"]
+        qcfg["smoothq_scale_layers"] = ["c_attn", "c_fc"]
         # NOTE: supported bigcode models do not need layer skip for large magnitude
         if "granite-3b-base-v2" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/granite_3b_base_v2_500_nw.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/granite_3b_base_v2_500_nw.pt"
         if "granite-13b-base-v2" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/granite_13b_base_v2.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/granite_13b_base_v2.pt"
         if "granite-20b-code-base" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt"
         if "granite-20b-code-instruct" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_20b_base12.pt"
         if "granite-34b-code-base" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
         if "granite-34b-code-instruct" in qcfg["model"]:
-            qcfg["act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
+            qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
     else:
         raise ValueError("The model architecture is not supported for DQ.")

fms_mo/utils/qconfig_utils.py

Lines changed: 37 additions & 8 deletions
@@ -18,7 +18,7 @@
 from datetime import date
 from importlib.metadata import version
 from pathlib import Path
-from typing import Any
+from typing import Any, Union
 import json
 import logging
 import os
@@ -113,6 +113,7 @@ def config_defaults() -> dict:
         "qkvsync": False,
         "extend_act_range": False,
         "plotsvg": False,
+        "qskip_large_mag_layers": False,
         # Iterable vars
         "qlayer_name_pattern": [],
         "qskip_layer_name": [],
@@ -142,21 +143,24 @@ def config_defaults() -> dict:
         "temp_disable_calib": False,
         "org_batch_size": {},
         "ptqmod_to_be_optimized": [],
+        # SmoothQuant vars
+        "smoothq": False,
+        "smoothq_scale_layers": [],
+        "smoothq_act_scale_path": None,
         # Other vars
         "which2patch_contextmanager": None,
         "force_stop_if_qbmm_auto_check_failed": False,
         "world_size": max(1, torch.cuda.device_count()),
         "global_rank": 0,
         "batch_size": 2,
+        "keys_to_save": [],
         # items could be obsoleted
         "output_attentions": False,
         "bias_corr": False,
         "qwav2vec": False,
         "qvit": False,
         "numparamsfromloadertomodel": 1,
         "gradclip": 0.0,
-        "smoothq": False,
-        "keys_to_save": [],
     }

     return cfg_defaults
@@ -201,7 +205,7 @@ def find_recipe_json(recipe: str, subdir: str = None) -> Path:
     return json_file


-def get_recipe(recipe: str, subdir: str = None) -> Any:
+def get_recipe(recipe: str, subdir: str = None) -> Union[list, dict]:
     """
     Get a json recipe.

@@ -219,6 +223,10 @@ def get_recipe(recipe: str, subdir: str = None) -> Any:
         temp_data = json.load(openfile)
     logger.info(f"Loaded settings from {json_file}.")

+    # Any recipe should be a dict (qcfg) or list (keys_to_save)
+    if not isinstance(temp_data, (dict, list)):
+        raise ValueError(f"Loaded recipe {json_file} was not a dict or list")
+
     return temp_data


@@ -378,8 +386,14 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict:
     # this can be used to load a previously saved ckpt as well
     if recipe:
         # qcfg recipes should reside in fms_mo/recipes
-        temp_cfg = get_recipe(recipe)
+        temp_cfg = qconfig_load(recipe)
+
         if temp_cfg:
+            if not isinstance(temp_cfg, dict):
+                raise ValueError(
+                    f"Quantized config recipe={recipe} is not a dictionary"
+                )
+
             qcfg.update(temp_cfg)
             logger.info("Updated config with recipe values")
         else:
@@ -562,7 +576,12 @@ def qconfig_save(

     # Next, check in fms_mo/recipes and merge them into a unique set (in case they differ)
     keys_to_save_json = get_recipe(recipe)
+
     if keys_to_save_json:
+        if not isinstance(keys_to_save_json, list):
+            raise ValueError(f"Save recipe={recipe} is not a list!")
+
+        # Merge keys_to_save lists
         keys_to_save = list(set(keys_to_save + keys_to_save_json))

     # If we found keys to save, fetch them from qcfg
@@ -604,9 +623,12 @@

 def qconfig_load(fname: str = "qcfg.json") -> dict:
     """Read config in json format, work together with qconfig_save"""
-    if os.path.isfile(fname):
-        with open(fname, "r", encoding="utf-8") as openfile:
-            config = json.load(openfile)
+    config = get_recipe(fname)
+
+    if config:
+        # Check that loaded file is a dict
+        if not isinstance(config, dict):
+            raise ValueError(f"Quantized config={fname} is not a dictionary")

         # Add back wanted defaults for any missing vars
         add_wanted_defaults_to_config(config, minimal=False)
@@ -856,6 +878,8 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "plotsvg",
         "ptq_freezecvs",
         "ptq_qdrop",
+        "qskip_large_mag_layers",
+        "smoothq",
     ]
     for boolean_var_str in boolean_vars_str:
         boolean_var = config.get(
@@ -912,6 +936,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "firstptqmodule",
         "params2optim",
         "clip_val_asst_percentile",
+        "smoothq_scale_layers",
     ]
     for iterable_var_str in iterable_vars_str:
         iterable_var_default = default_config.get(iterable_var_str)
@@ -990,3 +1015,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
             f"which2patch_contextmanager = {which2patch_contextmanager} is not one of "
             f"the following: {which2patch_contextmanager_settings}"
         )
+
+    smoothq_act_scale_path = config.get("smoothq_act_scale_path", None)
+    if smoothq_act_scale_path and not smoothq_act_scale_path.endswith(".pt"):
+        raise ValueError(f"{smoothq_act_scale_path=} is not a .pt checkpoint")
