
Commit fbdf19f

fix: type hinting arguments and returns
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent aca818a commit fbdf19f

File tree

4 files changed: +78 -57 lines changed


fms_mo/dq.py

Lines changed: 6 additions & 6 deletions
@@ -141,16 +141,16 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")

-    quant_mode = check_quantization_setting(model)
+    inference_only = check_quantization_setting(model)

-    if not quant_mode:
+    if not inference_only:
         logger.info("quantization mode activated, initalizing the qcfg file ")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
         qcfg = load_inference_qconfig_file(model_args, fms_mo_args)

-    if quant_mode:
+    if inference_only:
         model = convert_fp8_vllm_to_fms_mo(model=model)

     model_size = model_size_Wb(model, unit="GB")

@@ -176,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):

     qcfg["model"] = model_args.model_name_or_path
     # config layers to skip, smooth scale
-    if not quant_mode:
+    if not inference_only:
         config_quantize_smooth_layers(qcfg)

     use_dynamo = True

@@ -209,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )

     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not quant_mode and qcfg["smoothq"]:
+    if not inference_only and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)

@@ -248,7 +248,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)

-    if not quant_mode:
+    if not inference_only:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
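Note: the rename above reads as "this checkpoint is inference-only fp8, skip re-quantization". A minimal sketch of the branching run_dq now performs with the renamed flag; check_quantization_setting_sketch and the config dicts are illustrative stand-ins, not the fms_mo API.

# Minimal sketch of the control flow around the renamed flag.
# The helper and dicts below are placeholders, not real fms_mo objects.

def check_quantization_setting_sketch(model_config: dict) -> bool:
    """True only when the checkpoint declares a quantization_config (fp8 inference case)."""
    return "quantization_config" in model_config

def run_dq_branch_sketch(model_config: dict) -> str:
    inference_only = check_quantization_setting_sketch(model_config)
    if not inference_only:
        return "quantization mode: initialize a fresh dq qcfg"
    return "inference mode: load the qcfg shipped with the fp8 checkpoint"

print(run_dq_branch_sketch({}))  # quantization mode
print(run_dq_branch_sketch({"quantization_config": {"quant_method": "compressed-tensors"}}))  # inference mode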

fms_mo/prep.py

Lines changed: 4 additions & 4 deletions
@@ -394,17 +394,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     if available_packages["compressed_tensors"]:
         # Third Party
         import compressed_tensors
-
-        if isinstance(
+        # checks if the layer is CompressedLinear. If it is a CompressedLinear layer,
+        # it does nothing. Otherwise, it throws the warning sign
+        if not isinstance(
             module, compressed_tensors.linear.compressed_linear.CompressedLinear
         ):
-            pass
-        else:
             logger.warning(
                 f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
                 "Please make sure it doesn't wrap BN and activ func. Otherwise"
                 "please create an equivalent Linear wrapper and change qcfg['mapping']."
             )
+
     QLin = mapping.get(nn.Linear, None)
     if QLin is None:
         if verbose:
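Note: the prep.py change is the standard inversion of an "if ...: pass / else: ..." branch. A self-contained illustration of the same pattern, using a generic stand-in class rather than compressed_tensors' CompressedLinear:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class KnownWrapper:
    """Stand-in for a recognized Linear wrapper such as CompressedLinear."""

# Before: the positive branch did nothing, so the warning lived in `else`.
def warn_before(module):
    if isinstance(module, KnownWrapper):
        pass
    else:
        logger.warning("%s may be an unexpected Linear wrapper", type(module))

# After: inverting the test removes the dead `pass` branch.
def warn_after(module):
    if not isinstance(module, KnownWrapper):
        logger.warning("%s may be an unexpected Linear wrapper", type(module))

warn_after(object())        # warns
warn_after(KnownWrapper())  # silent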

fms_mo/training_args.py

Lines changed: 1 addition & 2 deletions
@@ -209,8 +209,7 @@ class FMSMOArguments(TypeChecker):
         default=False,
         metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
     )
-    fp8_use_subnormal: bool = field(default=False)
-    override_fms_args: bool = field(default=False)
+    override_qcfg_args: bool = field(default=False)


 @dataclass
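Note: the training_args.py change drops fp8_use_subnormal and renames override_fms_args to override_qcfg_args, both plain boolean fields. A small, hypothetical dataclass showing how such a flag is declared and used; this is not the real FMSMOArguments class, and the help text is illustrative only.

from dataclasses import dataclass, field

@dataclass
class ArgsSketch:
    # Hypothetical stand-in for FMSMOArguments; only the renamed flag is shown.
    override_qcfg_args: bool = field(
        default=False,
        metadata={"help": "Illustrative help string for the override flag."},
    )

args = ArgsSketch()
print(args.override_qcfg_args)              # False
print(ArgsSketch(override_qcfg_args=True))  # ArgsSketch(override_qcfg_args=True)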

fms_mo/utils/dq_inf.py

Lines changed: 67 additions & 45 deletions
@@ -17,6 +17,7 @@
 """

 # Standard
+from typing import Any, Dict, List, Tuple, Union
 import glob
 import json
 import logging

@@ -36,7 +37,7 @@
 logger = logging.getLogger(__name__)


-def check_quantization_setting(model: nn.Module = None):
+def check_quantization_setting(model: nn.Module) -> bool:
     """
     function checks if the checkpoint is from fp8 quantization
     """
@@ -47,36 +48,49 @@ def check_quantization_setting(model: nn.Module = None):
         return False

     logger.info("Validating config settings")
-    if quant_config["quant_method"] == "compressed-tensors":
-        if quant_config["format"] != "float-quantized":
-            raise ValueError(
-                "The input activation and weight quantization dtypes are not supported"
-            )
-
-        if (
-            quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"]
-            != 8
-        ):
-            raise ValueError("Only 8 bit FP input activation quantization is supported")
-
-        if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
-            raise ValueError("Only 8-bit FP weight quantization is supported")
-
-        if quant_config["kv_cache_scheme"] is None:
-            pass
-        else:
-            if quant_config["kv_cache_scheme"]["type"] is not float:
-                raise ValueError("The KV-Cache quantization dtype is not supported")
-
-            if quant_config["kv_cache_scheme"]["num_bits"] != 8:
-                raise ValueError("Only 8-bit KV-Cache quantization dtype is supported")
-
-        return True
+    if "quant_method" in quant_config.keys():
+        if quant_config["quant_method"] == "compressed-tensors":
+            if quant_config["format"] != "float-quantized":
+                raise ValueError(
+                    "The input activation and weight quantization dtypes are not supported"
+                )
+
+            if (
+                quant_config["config_groups"]["group_0"]["input_activations"][
+                    "num_bits"
+                ]
+                != 8
+            ):
+                raise ValueError(
+                    "Only 8 bit FP input activation quantization is supported"
+                )
+
+            if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
+                raise ValueError("Only 8-bit FP weight quantization is supported")
+
+            if quant_config["kv_cache_scheme"] is not None:
+                if quant_config["kv_cache_scheme"]["type"] is not float:
+                    raise ValueError("The KV-Cache quantization dtype is not supported")
+
+                if quant_config["kv_cache_scheme"]["num_bits"] != 8:
+                    raise ValueError(
+                        "Only 8-bit KV-Cache quantization dtype is supported"
+                    )
+
+            return True
+        raise ValueError(
+            "The quantization method is not supported for inferencing."
+            "Only Fp8 quantization is supported"
+        )

-    raise ValueError("This quantization method is not supported for inferencing")
+    raise ValueError(
+        "The quantization method is not found. Please check the config file"
+    )


-def load_inference_qconfig_file(model_args, fms_mo_args):
+def load_inference_qconfig_file(
+    model_args: Any = None, fms_mo_args: Any = None
+) -> Dict[str, Union[int, float, str]]:
     """
     Function to load the inference quantization config for fms_mo
     """
@@ -87,12 +101,13 @@ def load_inference_qconfig_file(model_args, fms_mo_args):
                 recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
             )
         else:
-            logger.info("qcfg file found, loading the qcfg file ")
+            logger.info(f"loading quantization configuration from\
+                {model_args.model_name_or_path + '/qcfg.json'}")
             qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
     else:
         logger.info(
-            f"qcfg file not found in {model_args.model_name_or_path},\
-            loading fms_mo_args and recipe"
+            f"qcfg file not found in {model_args.model_name_or_path},"
+            "loading fms_mo_args and recipe"
         )
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     qcfg = update_qcfg_from_model_config(model_args, qcfg)

@@ -101,7 +116,9 @@
     return qcfg


-def update_qcfg_from_model_config(model_args, qcfg):
+def update_qcfg_from_model_config(
+    model_args: Any = None, qcfg: dict = None
+) -> Dict[str, Union[int, float, str]]:
     """
     function to update the default qcfg setting with settings in the model config file.
     Important for the case where qcfg file does not exist.
@@ -144,15 +161,16 @@ def update_qcfg_from_model_config(model_args, qcfg):
     return qcfg


-# def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None):
-def rename_fms_dict_to_vllm_dict(model_dict: dict = None):
+def rename_fms_dict_to_vllm_dict(
+    model_dict: dict = None,
+) -> Tuple[Dict[str, Union[int, float]], Dict[str, Union[int, float]]]:
     """
     Function to rename the dict in fms_mo format to vllm_format.
     """
     st_dict = {}
     fms_dict = {}
     keys = model_dict.keys()
-
+    logger.info("WARNING: only static weights per-channel is supported at this time")
     for k, v in model_dict.items():
         if ".weight" in k:
             key = k.split("weight")[0]
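Note: the rename helpers in this file pair a weight tensor with its scale by splitting the key on "weight". A standalone sketch of that pairing on made-up state-dict keys; the key names are illustrative, not taken from a real checkpoint.

# Illustrative state-dict keys; real fms_mo/vllm checkpoints will differ.
model_dict = {
    "model.layers.0.self_attn.q_proj.weight": "tensor_w",
    "model.layers.0.self_attn.q_proj.weight_scale": "tensor_s",
    "model.layers.0.input_layernorm.weight": "tensor_ln",
}

for k in model_dict:
    if ".weight" in k:
        prefix = k.split("weight")[0]
        has_scale = prefix + "weight_scale" in model_dict
        print(f"{k}: paired with weight_scale = {has_scale}")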
@@ -167,7 +185,9 @@ def rename_fms_dict_to_vllm_dict(model_dict: dict = None):
     return st_dict, fms_dict


-def update_config(model_config_file: dict = None, qcfg: dict = None):
+def update_config(
+    model_config_file: dict = None, qcfg: dict = None
+) -> Dict[str, Union[int, str]]:
     """
     Function to update the model config file with quantization configuration
     """

@@ -181,7 +201,9 @@
     return model_config_file


-def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None):
+def save_vllm_fp8(
+    model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None
+) -> None:
     """
     Function to save fp8 DQ model in vllm fp8 format
     """

@@ -200,7 +222,9 @@ def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = No
         json.dump(config, f, indent=4)


-def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None):
+def convert_fms_mo_to_vllm_fp8_format(
+    checkpoint: str = None, folder: str = None
+) -> None:
     """
     Function to convert fp8 fms_mo DQ model checkpoint to vllm fp8 format
     """
@@ -231,7 +255,7 @@ def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None
         json.dump(config, f, indent=4)


-def find_file_glob(pattern: str, search_path: str):
+def find_file_glob(pattern: str, search_path: str) -> List[str]:
     """
     Finds files matching a pattern within a directory and its subdirectories.
     """

@@ -243,7 +267,7 @@

 def convert_fp8_vllm_dict_to_fms_mo_dict(
     checkpoint: str = None, output_dir: str = None
-):
+) -> None:
     """
     Function to help convert vllm fp8 checkpoint into fms_mo fp8 format
     """

@@ -257,7 +281,7 @@
     save_torch_state_dict(fms_mo_dict, output_dir)


-def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None):
+def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict:
     """
     Function to help rename vllm dict format to fms_mo dict format
     """
@@ -271,14 +295,12 @@ def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None):
             fms_mo_dict[k] = v
         else:
             key = k.split("weight")[0]
-            if key + "weight_scale" in vllm_dict.keys():
-                pass
-            else:
+            if key + "weight_scale" not in vllm_dict.keys():
                 fms_mo_dict[k] = v
     return fms_mo_dict


-def convert_fp8_vllm_to_fms_mo(model: nn.Module = None):
+def convert_fp8_vllm_to_fms_mo(model: nn.Module = None) -> nn.Module:
     """
     Function to help convert fp8 vllm model dict format to fms_mo fp8 format
     """
