Skip to content

Commit 60f1aa0

Browse files
committed
Add qconfig save during saving for AIU
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent c02f333 commit 60f1aa0

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

fms_mo/utils/aiu_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import logging
1818

1919
# Third Party
20+
from fms_mo.utils.qconfig_utils import qconfig_save
2021
from transformers.modeling_utils import PreTrainedModel
2122
import torch
2223

@@ -217,7 +218,7 @@ def convert_sd_for_aiu(
217218

218219
def save_sd_for_aiu(
219220
model: PreTrainedModel,
220-
output_dir: str,
221+
output_dir: str = "./",
221222
savename: str = "qmodel_state_dict.pt",
222223
verbose: bool = False,
223224
) -> None:
@@ -226,3 +227,31 @@ def save_sd_for_aiu(
226227
converted_sd = convert_sd_for_aiu(model, verbose)
227228
torch.save(converted_sd, Path(output_dir) / savename)
228229
logger.info("Model saved.")
230+
231+
232+
def save_for_aiu(
233+
model: PreTrainedModel,
234+
qcfg: dict,
235+
output_dir: str = "./",
236+
file_name: str = "qmodel.pt",
237+
cfg_name: str = "qcfg.json",
238+
recipe: str | None = None,
239+
verbose: bool = False,
240+
) -> None:
241+
"""Save quantized model and configuration in the format request by the AIU.
242+
The checkpoint saving is customized for AIU compatibility.
243+
The general qconfig_save function is used to save the quantization configuration.
244+
"""
245+
246+
save_sd_for_aiu(model, output_dir, file_name, verbose)
247+
248+
# define specific keys needed when reloading model for AIU
249+
qcfg["keys_to_save"] = [
250+
"qa_mode",
251+
"qw_mode",
252+
"smoothq",
253+
"scale_layers",
254+
"qskip_layer_name",
255+
"qskip_large_mag_layers",
256+
]
257+
qconfig_save(qcfg, recipe=recipe, minimal=True, fname=Path(output_dir) / cfg_name)

fms_mo/utils/qconfig_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,9 @@ def add_wanted_defaults_to_config(config: dict, minimal: bool = True) -> None:
538538

539539
def qconfig_save(
540540
qcfg: dict,
541-
recipe: str = None,
541+
recipe: str | None = None,
542542
minimal: bool = True,
543-
fname="qcfg.json",
543+
fname: str = "qcfg.json",
544544
) -> None:
545545
"""
546546
Try to save qcfg into a JSON file (or use .pt format if something really can't be text-only).
@@ -550,8 +550,8 @@ def qconfig_save(
550550
Args:
551551
qcfg (dict): Quantized config.
552552
recipe (str, optional): String name for a save recipe. Defaults to None.
553-
minimal (bool, optional): Save a minimal quantized config. Defaults to True.
554-
fname (str, optional): File name to save quantized config. Defaults to "qcfg.json".
553+
minimal (bool): Save a minimal quantized config. Defaults to True.
554+
fname (str): File name to save quantized config. Defaults to "qcfg.json".
555555
"""
556556

557557
# First check in qcfg for added save list
@@ -598,6 +598,7 @@ def qconfig_save(
598598
warnings.warn(message, UserWarning)
599599
with open(fname, "w", encoding="utf-8") as outfile:
600600
json.dump(temp_qcfg, outfile, indent=4)
601+
logger.info(f"Quantization configuration saved to {fname}")
601602

602603

603604
def qconfig_load(fname: str = "qcfg.json") -> dict:

0 commit comments

Comments
 (0)