
Commit 31dd8c7

committed
fix: updated the inference file
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent adb7f38 commit 31dd8c7

5 files changed: +128 / -57 lines


fms_mo/dq.py

Lines changed: 12 additions & 36 deletions

@@ -21,8 +21,6 @@
 # Standard
 from pathlib import Path
 import logging
-import os
-import sys

 # Third Party
 from datasets import load_from_disk
@@ -52,6 +50,7 @@
 from fms_mo.utils.dq_inf import (
     check_quantization_setting,
     convert_fp8_vllm_to_fms_mo,
+    load_inference_qconfig_file,
     save_vllm_fp8,
 )
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
@@ -134,18 +133,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )

-    inference_qconfig = None
-    if hasattr(model, "config"):
-        inference_qconfig = model.config.to_dict().get("quantization_config", None)
-
-    if inference_qconfig:
-        quant_setting = check_quantization_setting(inference_qconfig)
-        if quant_setting:
-            logger.info("Quantization config settings validated ")
-            model = convert_fp8_vllm_to_fms_mo(model=model)
-        else:
-            sys.exit("Error: This quantization config is wrong/not supported")
-
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
@@ -154,29 +141,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")

-    if not inference_qconfig:
+    quant_mode = check_quantization_setting(model)
+
+    if not quant_mode:
         logger.info("quantization mode activated, initalizing the qcfg file ")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
-        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
-            if fms_mo_args.override_fms_args:
-                logger.info(
-                    "qcfg file found and some parameters are being over-written "
-                )
-                qcfg = qconfig_init(
-                    recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
-                )
-            else:
-                logger.info("qcfg file found, loading the qcfg file ")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
-        else:
-            logger.info(
-                "qcfg file not found in {model_args.model_name_or_path},\
-                loading fms_mo_args and recipe"
-            )
-            qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-        qcfg["fp8_inference"] = True
+        qcfg = load_inference_qconfig_file(model_args, fms_mo_args)
+
+    if quant_mode:
+        model = convert_fp8_vllm_to_fms_mo(model=model)

     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -201,7 +176,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):

     qcfg["model"] = model_args.model_name_or_path
     # config layers to skip, smooth scale
-    if not inference_qconfig:
+    if not quant_mode:
         config_quantize_smooth_layers(qcfg)

     use_dynamo = True
@@ -234,7 +209,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )

     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not inference_qconfig and qcfg["smoothq"]:
+    if not quant_mode and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -272,7 +247,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
-    if not inference_qconfig:
+
+    if not quant_mode:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
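
For orientation, a minimal sketch of the rewritten branch in run_dq after this commit, condensed from the hunks above (sketch only, using just the names that appear in the diff; not a drop-in excerpt of the file):

# Sketch of the inference/quantization split in run_dq after this change.
quant_mode = check_quantization_setting(model)  # True for a supported FP8 checkpoint

if not quant_mode:
    # quantization mode: build qcfg from the "dq" recipe and the CLI args
    qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
else:
    # inference mode: qcfg handling now lives in fms_mo/utils/dq_inf.py
    qcfg = load_inference_qconfig_file(model_args, fms_mo_args)

if quant_mode:
    model = convert_fp8_vllm_to_fms_mo(model=model)

# Later guards that previously tested `inference_qconfig` now test `quant_mode`,
# e.g. config_quantize_smooth_layers and the smoothquant-scale steps.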

fms_mo/prep.py

Lines changed: 10 additions & 10 deletions

@@ -395,16 +395,16 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
         # Third Party
         import compressed_tensors

-        if isinstance(
-            module, compressed_tensors.linear.compressed_linear.CompressedLinear
-        ):
-            pass
-        else:
-            logger.warning(
-                f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
-                "Please make sure it doesn't wrap BN and activ func. Otherwise"
-                "please create an equivalen Linear wrapper and change qcfg['mapping']."
-            )
+        if isinstance(
+            module, compressed_tensors.linear.compressed_linear.CompressedLinear
+        ):
+            pass
+        else:
+            logger.warning(
+                f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
+                "Please make sure it doesn't wrap BN and activ func. Otherwise"
+                "please create an equivalent Linear wrapper and change qcfg['mapping']."
+            )
         QLin = mapping.get(nn.Linear, None)
         if QLin is None:
             if verbose:

fms_mo/quant/quantizers.py

Lines changed: 1 addition & 2 deletions

@@ -237,7 +237,6 @@ def get_weight_quantizer(
     recompute=False,
     perGp=None,
     use_subnormal=False,
-    emulate=True,
 ):
     """Return a quantizer for weight quantization
     Regular quantizers:
@@ -347,7 +346,7 @@ def get_weight_quantizer(
         weight_quantizer = to_fp8(
             nbits,
             q_mode=qw_mode,
-            emulate=emulate,
+            emulate=True,
             perCh=Nch,
         )
     else:

fms_mo/utils/dq_inf.py

Lines changed: 104 additions & 8 deletions

@@ -29,23 +29,119 @@
 import torch

 # Local
+from fms_mo import qconfig_init
 from fms_mo.quant.quantizers import to_fp8_scaled_perCh
 from fms_mo.utils.qconfig_utils import get_recipe

 logger = logging.getLogger(__name__)


-def check_quantization_setting(inference: dict = None):
+def check_quantization_setting(model: nn.Module = None):
     """
     function checks if the checkpoint is from fp8 quantization
     """
-    return (
-        inference["config_groups"]["group_0"]["input_activations"]["num_bits"] == 8
-        and inference["config_groups"]["group_0"]["weights"]["num_bits"] == 8
-        and inference["config_groups"]["group_0"]["weights"]["type"] == "float"
-        and inference["config_groups"]["group_0"]["input_activations"]["type"]
-        == "float"
-    )
+    quant_config = None
+    if hasattr(model, "config"):
+        quant_config = model.config.to_dict().get("quantization_config", None)
+    if quant_config is None:
+        return False
+
+    logger.info("Validating config settings")
+    if quant_config["quant_method"] == "compressed-tensors":
+        if quant_config["format"] != "float-quantized":
+            raise Exception(
+                "The input activation and weight quantization dtypes are not supported"
+            )
+
+        if (
+            quant_config["config_groups"]["group_0"]["input_activations"]["num_bits"]
+            != 8
+        ):
+            raise Exception("Only 8-bit FP input activation quantization is supported")
+
+        if quant_config["config_groups"]["group_0"]["weights"]["num_bits"] != 8:
+            raise Exception("Only 8-bit FP weight quantization is supported")
+
+        if quant_config["kv_cache_scheme"] is None:
+            pass
+        else:
+            if quant_config["kv_cache_scheme"]["type"] != "float":
+                raise Exception("The KV-Cache quantization dtype is not supported")
+
+            if quant_config["kv_cache_scheme"]["num_bits"] != 8:
+                raise Exception("Only 8-bit KV-Cache quantization dtype is supported")
+
+        return True
+
+    raise Exception("This quantization method is not supported for inferencing")
+
+
+def load_inference_qconfig_file(model_args, fms_mo_args):
+    """
+    Function to load the inference quantization config for fms_mo
+    """
+    if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
+        if fms_mo_args.override_qcfg_args:
+            logger.info("qcfg file found and some parameters are being over-written")
+            qcfg = qconfig_init(
+                recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
+            )
+        else:
+            logger.info("qcfg file found, loading the qcfg file ")
+            qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
+    else:
+        logger.info(
+            f"qcfg file not found in {model_args.model_name_or_path},\
+            loading fms_mo_args and recipe"
+        )
+        qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+        qcfg = update_qcfg_from_model_config(model_args, qcfg)
+    qcfg["fp8_inference"] = True
+
+    return qcfg
+
+
+def update_qcfg_from_model_config(model_args, qcfg):
+    """
+    function to update the default qcfg setting with settings in the model config file.
+    Important for the case where qcfg file does not exist.
+    """
+    config = get_recipe(model_args.model_name_or_path + "/config")
+    if (
+        config["quantization_config"]["config_groups"]["group_0"]["input_activations"][
+            "strategy"
+        ]
+        == "token"
+    ):
+        qcfg["qa_mode"] = "fp8_e4m3_scale_perToken"
+    else:
+        raise Exception("Only perToken FP8 activation quantizer is supported")
+
+    if (
+        config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"]
+        == "channel"
+    ):
+        qcfg["qw_mode"] = "fp8_e4m3_scale_perCh"
+    elif (
+        config["quantization_config"]["config_groups"]["group_0"]["weights"]["strategy"]
+        == "tensor"
+    ):
+        qcfg["qw_mode"] = "fp8_e4m3_scale"
+    else:
+        raise Exception(
+            "Only per-channel or per-tensor FP8 weight quantizers are currently supported"
+        )
+
+    qcfg["smoothq"] = False
+    qcfg["nbits_a"] = config["quantization_config"]["config_groups"]["group_0"][
+        "input_activations"
+    ]["num_bits"]
+    qcfg["nbits_w"] = config["quantization_config"]["config_groups"]["group_0"][
+        "weights"
+    ]["num_bits"]
+    qcfg["torch_dtype"] = "float16"
+
+    return qcfg


 # def rename_fms_dict_to_vllm_dict (model_dict : dict= None, qcfg : dict = None):
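
As a rough illustration of what the new helpers expect, a hedged sketch of a quantization_config that check_quantization_setting would accept and that update_qcfg_from_model_config would map onto qcfg (field names are taken from the code above; the concrete dict is illustrative, not copied from a real checkpoint):

# Illustrative compressed-tensors FP8 config fragment (example values only).
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "float-quantized",
    "kv_cache_scheme": None,  # or an 8-bit "float" scheme
    "config_groups": {
        "group_0": {
            "input_activations": {"num_bits": 8, "type": "float", "strategy": "token"},
            "weights": {"num_bits": 8, "type": "float", "strategy": "channel"},
        }
    },
}

# With a model whose config carries a quantization_config like the above,
# check_quantization_setting(model) returns True, and (absent a qcfg.json)
# update_qcfg_from_model_config would roughly set:
#   qcfg["qa_mode"] = "fp8_e4m3_scale_perToken"   # activations: strategy == "token"
#   qcfg["qw_mode"] = "fp8_e4m3_scale_perCh"      # weights: strategy == "channel"
#   qcfg["nbits_a"] = qcfg["nbits_w"] = 8
#   qcfg["smoothq"] = False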

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ dependencies = [

 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao==0.11", "compressed_tensors"]
+fp8 = ["llmcompressor", "torchao==0.11"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
