1717"""
1818
1919# Standard
20+ from typing import Any , Dict , List , Tuple , Union
2021import glob
2122import json
2223import logging
3637logger = logging .getLogger (__name__ )
3738
3839
39- def check_quantization_setting (model : nn .Module = None ) :
40+ def check_quantization_setting (model : nn .Module ) -> bool :
4041 """
4142 function checks if the checkpoint is from fp8 quantization
4243 """
@@ -47,36 +48,49 @@ def check_quantization_setting(model: nn.Module = None):
        return False

    logger.info("Validating config settings")
-    if quant_config["quant_method"] == "compressed-tensors":
-        if quant_config["format"] != "float-quantized":
-            raise ValueError(
-                "The input activation and weight quantization dtypes are not supported"
-            )
-
-        if (
-            quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"]
-            != 8
-        ):
-            raise ValueError("Only 8 bit FP input activation quantization is supported")
-
-        if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
-            raise ValueError("Only 8-bit FP weight quantization is supported")
-
-        if quant_config["kv_cache_scheme"] is None:
-            pass
-        else:
-            if quant_config["kv_cache_scheme"]["type"] is not float:
-                raise ValueError("The KV-Cache quantization dtype is not supported")
-
-            if quant_config["kv_cache_scheme"]["num_bits"] != 8:
-                raise ValueError("Only 8-bit KV-Cache quantization dtype is supported")
-
-        return True
51+ if "quant_method" in quant_config .keys ():
52+ if quant_config ["quant_method" ] == "compressed-tensors" :
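+            # The checks below assume the compressed-tensors layout of the
+            # "quantization_config" section in config.json: a "format" field plus
+            # config_groups -> group_0 -> input_activations / weights, and an
+            # optional kv_cache_scheme.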
+            if quant_config["format"] != "float-quantized":
+                raise ValueError(
+                    "The input activation and weight quantization dtypes are not supported"
+                )
+
+            if (
+                quant_config["config_groups"]["group_0"]["input_activations"][
+                    "num_bits"
+                ]
+                != 8
+            ):
+                raise ValueError(
+                    "Only 8-bit FP input activation quantization is supported"
+                )
+
+            if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
+                raise ValueError("Only 8-bit FP weight quantization is supported")
+
+            if quant_config["kv_cache_scheme"] is not None:
+                if quant_config["kv_cache_scheme"]["type"] != "float":
+                    raise ValueError("The KV-Cache quantization dtype is not supported")
+
+                if quant_config["kv_cache_scheme"]["num_bits"] != 8:
+                    raise ValueError(
+                        "Only 8-bit KV-Cache quantization dtype is supported"
+                    )
+
+            return True
+        raise ValueError(
+            "The quantization method is not supported for inferencing. "
+            "Only FP8 quantization is supported"
+        )

-    raise ValueError("This quantization method is not supported for inferencing")
+    raise ValueError(
+        "The quantization method was not found. Please check the config file"
+    )


-def load_inference_qconfig_file(model_args, fms_mo_args):
+def load_inference_qconfig_file(
+    model_args: Any = None, fms_mo_args: Any = None
+) -> Dict[str, Union[int, float, str]]:
    """
    Function to load the inference quantization config for fms_mo
    """
@@ -87,12 +101,13 @@ def load_inference_qconfig_file(model_args, fms_mo_args):
                recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
            )
        else:
-            logger.info("qcfg file found, loading the qcfg file ")
+            logger.info(
+                "loading quantization configuration from "
+                f"{model_args.model_name_or_path + '/qcfg.json'}"
+            )
            qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
    else:
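        # No qcfg file shipped with the checkpoint: build the config from the
        # "dq" recipe and fms_mo_args, then pull settings from the model config.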
        logger.info(
-            f"qcfg file not found in {model_args.model_name_or_path},\
-            loading fms_mo_args and recipe"
+            f"qcfg file not found in {model_args.model_name_or_path},"
+            " loading fms_mo_args and recipe"
        )
        qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
        qcfg = update_qcfg_from_model_config(model_args, qcfg)
@@ -101,7 +116,9 @@ def load_inference_qconfig_file(model_args, fms_mo_args):
    return qcfg


-def update_qcfg_from_model_config(model_args, qcfg):
+def update_qcfg_from_model_config(
+    model_args: Any = None, qcfg: dict = None
+) -> Dict[str, Union[int, float, str]]:
    """
    Function to update the default qcfg settings with the settings in the model config file.
    Important for the case where a qcfg file does not exist.
@@ -144,15 +161,16 @@ def update_qcfg_from_model_config(model_args, qcfg):
    return qcfg


-# def rename_fms_dict_to_vllm_dict(model_dict: dict = None, qcfg: dict = None):
-def rename_fms_dict_to_vllm_dict(model_dict: dict = None):
+def rename_fms_dict_to_vllm_dict(
+    model_dict: dict = None,
+) -> Tuple[Dict[str, Union[int, float]], Dict[str, Union[int, float]]]:
    """
    Function to rename a dict from fms_mo format to vLLM format.
    """
    st_dict = {}
    fms_dict = {}
    keys = model_dict.keys()
-
+    logger.warning("Only static per-channel weight quantization is supported at this time")
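    # Split the fms_mo state dict into vLLM-format entries (st_dict) and the
    # remaining fms_mo entries (fms_dict).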
    for k, v in model_dict.items():
        if ".weight" in k:
            key = k.split("weight")[0]
@@ -167,7 +185,9 @@ def rename_fms_dict_to_vllm_dict(model_dict: dict = None):
    return st_dict, fms_dict


-def update_config(model_config_file: dict = None, qcfg: dict = None):
+def update_config(
+    model_config_file: dict = None, qcfg: dict = None
+) -> Dict[str, Union[int, str]]:
    """
    Function to update the model config file with quantization configuration
    """
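    # This quantization section is what vLLM reads from config.json to recognize
    # an FP8 (compressed-tensors) checkpoint.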
@@ -181,7 +201,9 @@ def update_config(model_config_file: dict = None, qcfg: dict = None):
    return model_config_file


-def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None):
+def save_vllm_fp8(
+    model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None
+) -> None:
    """
    Function to save fp8 DQ model in vllm fp8 format
    """
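    # Expected to write the converted tensors and a config.json updated with the
    # quantization settings (see update_config) into `folder`.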
@@ -200,7 +222,9 @@ def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None):
        json.dump(config, f, indent=4)


-def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None):
+def convert_fms_mo_to_vllm_fp8_format(
+    checkpoint: str = None, folder: str = None
+) -> None:
    """
    Function to convert fp8 fms_mo DQ model checkpoint to vllm fp8 format
    """
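    # Illustrative usage (hypothetical paths):
    #   convert_fms_mo_to_vllm_fp8_format("ckpts/fms_mo_fp8_dq", "ckpts/vllm_fp8")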
@@ -231,7 +255,7 @@ def convert_fms_mo_to_vllm_fp8_format(checkpoint: str = None, folder: str = None):
        json.dump(config, f, indent=4)


-def find_file_glob(pattern: str, search_path: str):
+def find_file_glob(pattern: str, search_path: str) -> List[str]:
    """
    Finds files matching a pattern within a directory and its subdirectories.
    """
@@ -243,7 +267,7 @@ def find_file_glob(pattern: str, search_path: str):

def convert_fp8_vllm_dict_to_fms_mo_dict(
    checkpoint: str = None, output_dir: str = None
-):
+) -> None:
    """
    Function to help convert vllm fp8 checkpoint into fms_mo fp8 format
    """
@@ -257,7 +281,7 @@ def convert_fp8_vllm_dict_to_fms_mo_dict(
    save_torch_state_dict(fms_mo_dict, output_dir)


-def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None):
+def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict:
    """
    Function to help rename vllm dict format to fms_mo dict format
    """
@@ -271,14 +295,12 @@ def rename_vllm_dict_to_fms_mo(vllm_dict: dict = None):
            fms_mo_dict[k] = v
        else:
            key = k.split("weight")[0]
-            if key + "weight_scale" in vllm_dict.keys():
-                pass
-            else:
+            if key + "weight_scale" not in vllm_dict.keys():
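                # Weights without a matching weight_scale are not FP8-quantized;
                # carry them over unchanged.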
                fms_mo_dict[k] = v
    return fms_mo_dict


-def convert_fp8_vllm_to_fms_mo(model: nn.Module = None):
+def convert_fp8_vllm_to_fms_mo(model: nn.Module = None) -> nn.Module:
    """
    Function to help convert fp8 vllm model dict format to fms_mo fp8 format
    """