1717"""
1818
1919# Standard
20- from typing import Any , Dict , List , Tuple , Union
20+ from typing import Any
2121import glob
2222import json
2323import logging
@@ -90,7 +90,7 @@ def check_quantization_setting(model: nn.Module) -> bool:
9090
9191def load_inference_qconfig_file (
9292 model_args : Any = None , fms_mo_args : Any = None
93- ) -> Dict [str , Union [ int , float , str ] ]:
93+ ) -> dict [str , int | float | str ]:
9494 """
9595 Function to load the inference quantization config for fms_mo
9696 """
@@ -118,7 +118,7 @@ def load_inference_qconfig_file(
118118
119119def update_qcfg_from_model_config (
120120 model_args : Any = None , qcfg : dict = None
121- ) -> Dict [str , Union [ int , float , str ] ]:
121+ ) -> dict [str , int | float | str ]:
122122 """
123123 function to update the default qcfg setting with settings in the model config file.
124124 Important for the case where qcfg file does not exist.
@@ -157,13 +157,18 @@ def update_qcfg_from_model_config(
157157 "weights"
158158 ]["num_bits" ]
159159 qcfg ["torch_dtype" ] = "float16"
160-
160+ if config ["quantization_config" ]["ignore" ]:
161+ qcfg ["qskip_layer_name" ] = config ["quantization_config" ]["ignore" ]
162+ qcfg ["qskip_large_mag_layers" ] = True
163+ else :
164+ qcfg ["qskip_layer_name" ] = []
165+ qcfg ["qskip_large_mag_layers" ] = False
161166 return qcfg
162167
163168
164169def rename_fms_dict_to_vllm_dict (
165170 model_dict : dict = None ,
166- ) -> Tuple [ Dict [str , Union [ int , float ]], Dict [str , Union [ int , float ] ]]:
171+ ) -> tuple [ dict [str , float | int ], dict [str , float | int ]]:
167172 """
168173 Function to rename the dict in fms_mo format to vllm_format.
169174 """
@@ -187,7 +192,7 @@ def rename_fms_dict_to_vllm_dict(
187192
188193def update_config (
189194 model_config_file : dict = None , qcfg : dict = None
190- ) -> Dict [str , Union [ int , str ] ]:
195+ ) -> dict [str , float | int | str ]:
191196 """
192197 Function to update the model config file with quantization configuration
193198 """
@@ -196,7 +201,8 @@ def update_config(
196201 data ["quantization_config" ]["config_groups" ]["group_0" ]["weights" ] = (
197202 "{num_bits: 8, type: float, symmetric: true, strategy: tensor}"
198203 )
199-
204+ if qcfg ["qskip_large_mag_layers" ]:
205+ data ["quantization_config" ]["ignore" ] = qcfg ["qskip_layer_name" ]
200206 model_config_file .update (data )
201207 return model_config_file
202208
@@ -255,7 +261,7 @@ def convert_fms_mo_to_vllm_fp8_format(
255261 json .dump (config , f , indent = 4 )
256262
257263
258- def find_file_glob (pattern : str , search_path : str ) -> List [str ]:
264+ def find_file_glob (pattern : str , search_path : str ) -> list [str ]:
259265 """
260266 Finds files matching a pattern within a directory and its subdirectories.
261267 """
@@ -281,7 +287,7 @@ def convert_fp8_vllm_dict_to_fms_mo_dict(
281287 save_torch_state_dict (fms_mo_dict , output_dir )
282288
283289
284- def rename_vllm_dict_to_fms_mo (vllm_dict : dict ) -> dict :
290+ def rename_vllm_dict_to_fms_mo (vllm_dict : dict ) -> dict [ str , float | int | str ] :
285291 """
286292 Function to help rename vllm dict format to fms_mo dict format
287293 """
0 commit comments