
Commit d3e7c61

fix: improving argument hints and inferencing for models with skipped layers
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent fbdf19f commit d3e7c61

2 files changed (+16, -9 lines)
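The "argument hints" part of the change replaces the typing.Dict / List / Tuple / Union aliases with built-in generics and PEP 604 unions, as the diff below shows. A minimal before/after sketch of that annotation style (the function names here are illustrative, not from the repository; the new form needs Python 3.10+ at runtime):

from typing import Any, Dict, Union  # only needed for the old-style annotation below


# Before: typing aliases
def load_qcfg_old(model_args: Any = None) -> Dict[str, Union[int, float, str]]:
    return {}


# After: built-in generics plus PEP 604 unions
def load_qcfg_new(model_args: Any = None) -> dict[str, int | float | str]:
    return {}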

fms_mo/utils/dq_inf.py

Lines changed: 15 additions & 9 deletions

@@ -17,7 +17,7 @@
 """

 # Standard
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
 import glob
 import json
 import logging
@@ -90,7 +90,7 @@ def check_quantization_setting(model: nn.Module) -> bool:

 def load_inference_qconfig_file(
     model_args: Any = None, fms_mo_args: Any = None
-) -> Dict[str, Union[int, float, str]]:
+) -> dict[str, int | float | str]:
     """
     Function to load the inference quantization config for fms_mo
     """
@@ -118,7 +118,7 @@ def load_inference_qconfig_file(

 def update_qcfg_from_model_config(
     model_args: Any = None, qcfg: dict = None
-) -> Dict[str, Union[int, float, str]]:
+) -> dict[str, int | float | str]:
     """
     function to update the default qcfg setting with settings in the model config file.
     Important for the case where qcfg file does not exist.
@@ -157,13 +157,18 @@ def update_qcfg_from_model_config(
             "weights"
         ]["num_bits"]
         qcfg["torch_dtype"] = "float16"
-
+        if config["quantization_config"]["ignore"] is not []:
+            qcfg["qskip_layer_name"] = config["quantization_config"]["ignore"]
+            qcfg["qskip_large_mag_layers"] = True
+        else:
+            qcfg["qskip_layer_name"] = []
+            qcfg["qskip_large_mag_layers"] = False
     return qcfg


 def rename_fms_dict_to_vllm_dict(
     model_dict: dict = None,
-) -> Tuple[Dict[str, Union[int, float]], Dict[str, Union[int, float]]]:
+) -> tuple[dict[str, float | int], dict[str, float | int]]:
     """
     Function to rename the dict in fms_mo format to vllm_format.
     """
@@ -187,7 +192,7 @@ def rename_fms_dict_to_vllm_dict(

 def update_config(
     model_config_file: dict = None, qcfg: dict = None
-) -> Dict[str, Union[int, str]]:
+) -> dict[str, float | int | str]:
     """
     Function to update the model config file with quantization configuration
     """
@@ -196,7 +201,8 @@ def update_config(
     data["quantization_config"]["config_groups"]["group_0"]["weights"] = (
         "{num_bits: 8, type: float, symmetric: true, strategy: tensor}"
     )
-
+    if qcfg["qskip_large_mag_layers"] == True:
+        data["quantization_config"]["ignore"] = qcfg["qskip_layer_name"]
     model_config_file.update(data)
     return model_config_file

@@ -255,7 +261,7 @@ def convert_fms_mo_to_vllm_fp8_format(
         json.dump(config, f, indent=4)


-def find_file_glob(pattern: str, search_path: str) -> List[str]:
+def find_file_glob(pattern: str, search_path: str) -> list[str]:
     """
     Finds files matching a pattern within a directory and its subdirectories.
     """
@@ -281,7 +287,7 @@ def convert_fp8_vllm_dict_to_fms_mo_dict(
     save_torch_state_dict(fms_mo_dict, output_dir)


-def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict:
+def rename_vllm_dict_to_fms_mo(vllm_dict: dict) -> dict[str, float | int | str]:
     """
     Function to help rename vllm dict format to fms_mo dict format
     """

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@
     "torchvision",
     "huggingface_hub",
     "torchao",
+    "compressed_tensors"
 ]

 available_packages = {}
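The import_utils change only registers "compressed_tensors" in the package list. Assuming the surrounding module populates available_packages with an importlib-based probe (that code is not part of this diff), the check would look roughly like:

import importlib.util

required_packages = ["torchvision", "huggingface_hub", "torchao", "compressed_tensors"]

available_packages = {}
for pkg in required_packages:
    # find_spec returns None when the package cannot be imported
    available_packages[pkg] = importlib.util.find_spec(pkg) is not None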
