
Commit 71b5aa4

fix: correcting errors
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent 0b5d68a commit 71b5aa4

6 files changed, +143 -91 lines changed

fms_mo/dq.py

Lines changed: 27 additions & 22 deletions
@@ -1,11 +1,11 @@
 # Copyright The FMS Model Optimizer Authors
-
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,6 +21,7 @@
 # Standard
 from pathlib import Path
 import logging
+import os
 
 # Third Party
 from datasets import load_from_disk
@@ -33,8 +34,8 @@
     default_data_collator,
 )
 import torch
+import sys
 
-import os
 # Local
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.custom_ext_kernels.utils import (
@@ -48,14 +49,14 @@
     get_act_scales_1gpu,
 )
 from fms_mo.utils.aiu_utils import save_for_aiu
-from fms_mo.utils.dq_utils import config_quantize_smooth_layers
-from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
-from fms_mo.utils.utils import patch_torch_bmm, prepare_input
 from fms_mo.utils.dq_inf import (
-    save_vllm_fp8,
-    convert_fp8_vllm_to_fms_mo,
     check_quantization_setting,
+    convert_fp8_vllm_to_fms_mo,
+    save_vllm_fp8,
 )
+from fms_mo.utils.dq_utils import config_quantize_smooth_layers
+from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
+from fms_mo.utils.utils import patch_torch_bmm, prepare_input
 
 logger = logging.getLogger(__name__)
 
@@ -133,16 +134,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )
 
-    inference= model.config.to_dict().get("quantization_config",None)
+    inference = model.config.to_dict().get("quantization_config", None)
 
     if inference:
         quant_setting = check_quantization_setting(inference)
         if quant_setting:
             logger.info("Quantization config settings validated ")
-            model = convert_fp8_vllm_to_fms_mo(model = model)
+            model = convert_fp8_vllm_to_fms_mo(model=model)
         else:
-            exit("__This quantization config is wrong/not supported__")
-
+            sys.exit("Error: This quantization config is wrong/not supported")
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
@@ -157,17 +157,22 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
-        if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"):
+        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
            if fms_mo_args.override_fms_args:
-                logger.info("qcfg file found and some parameters are being over-written ")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args)
+                logger.info(
+                    "qcfg file found and some parameters are being over-written "
+                )
+                qcfg = qconfig_init(
+                    recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
+                )
            else:
                logger.info("qcfg file found, loading the qcfg file ")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
         else:
-            logger.info("qcfg file not found in {model_args.model_name_or_path},\
+            logger.info(
+                "qcfg file not found in {model_args.model_name_or_path},\
                 loading fms_mo_args and recipe"
-            )
+            )
             qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
 
     model_size = model_size_Wb(model, unit="GB")
@@ -225,7 +230,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not inference and qcfg["smoothq"] :
+    if not inference and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -295,11 +300,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         logger.info(
             f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
         )
-        save_vllm_fp8(model,qcfg,tokenizer,opt_args.output_dir)
+        save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
     elif opt_args.save_ckpt:
         logger.info(
             f"Saving quantized model and tokenizer to {opt_args.output_dir}"
-            )
+        )
         model.save_pretrained(opt_args.output_dir, use_safetensors=True)
         tokenizer.save_pretrained(opt_args.output_dir)
 
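Note on the exit() to sys.exit() change above: the bare exit() builtin is an interactive convenience injected by Python's site module and is not guaranteed to be available in every runtime (for example under python -S or in embedded interpreters), while sys.exit() raises SystemExit and prints its string argument to stderr before terminating. A minimal sketch of the pattern, not part of this commit (validate_quant_config is a hypothetical helper name):

    import sys

    def validate_quant_config(quant_setting) -> None:
        # Abort with a readable message via SystemExit instead of relying on
        # the site-injected exit(); mirrors the error path in run_dq().
        if not quant_setting:
            sys.exit("Error: This quantization config is wrong/not supported")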

fms_mo/modules/linear.py

Lines changed: 2 additions & 1 deletion
@@ -281,7 +281,7 @@ def forward(self, x):
                 )
 
             # pylint: disable=not-callable
-
+
             return F.linear(x, self.W_fp, self.bias)
         else:
             qinput = self.quantize_feature(x / scale).to(x.dtype)
@@ -297,6 +297,7 @@ def forward(self, x):
             )
 
             qbias = self.bias
+
             # pylint: disable=not-callable
             output = F.linear(qinput, qweight, qbias)
 

fms_mo/prep.py

Lines changed: 18 additions & 15 deletions
@@ -22,8 +22,9 @@
 
 # Third Party
 from torch import nn
-import torch
 import compressed_tensors
+import torch
+
 # Local
 from fms_mo.calib import qmodel_calib
 from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
@@ -391,14 +392,16 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     # For nn.Linear
     elif isinstance(module, nn.Linear):
         if module.__class__ != nn.Linear:
-            if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear):
+            if isinstance(
+                module, compressed_tensors.linear.compressed_linear.CompressedLinear
+            ):
                 pass
             else:
                 logger.warning(
                     f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
-                    "Please make sure it doesn't wrap BN and activ func."
-                    "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
-                )
+                    "Please make sure it doesn't wrap BN and activ func. Otherwise"
+                    "please create an equivalen Linear wrapper and change qcfg['mapping']."
+                )
         QLin = mapping.get(nn.Linear, None)
         if QLin is None:
             if verbose:
@@ -572,6 +575,7 @@ def has_quantized_module(model):
     """Check if model is already quantized - do not want to quantize twice if so"""
     return any(isinstance(m, quantized_modules) for m in model.modules())
 
+
 def swap_qbmm(model: nn.Module, qcfg: dict):
     """Go through all model.named_modules(), try to create an equivalent
     Qbmm layer to replace each of the existing linear Bmm layers.
@@ -581,14 +585,13 @@ def swap_qbmm(model: nn.Module, qcfg: dict):
         qcfg (dict): quant config
 
     Returns: updated model is returned with the Qbmm added
-
+
     """
 
+    # Local
     from fms_mo.modules import QBmm
 
-    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
-        "which2patch_contextmanager"
-    ]
+    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"]
     isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
     for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
         mod_bmm_happened = model.get_submodule(mod_name)
@@ -608,6 +611,7 @@ def swap_qbmm(model: nn.Module, qcfg: dict):
             )
         setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
 
+
 def qmodel_prep(
     model,
     dloader,
@@ -696,13 +700,12 @@ def qmodel_prep(
         nn.Module: quantized model ready for further PTQ/QAT
     """
     if mode:
-
-        if qcfg.get("QBmm"):
-            swap_qbmm(model,qcfg)
+        if qcfg.get("QBmm"):
+            swap_qbmm(model, qcfg)
 
-        model = q_any_net_5(model, qcfg, verbose = False)
+        model = q_any_net_5(model, qcfg, verbose=False)
         return model
-
+
     sys.setrecursionlimit(4000)
 
     currDev = next(model.parameters()).device if dev is None else dev
@@ -952,7 +955,7 @@ def qmodel_prep(
             model, device_ids=DPorDDPdevices
         )
 
-    qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json")
+    qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
    qcfg["tb_writer"] = tb_writer
 
    logger.info(f"--- Quantized model --- \n{model}\n")
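Aside on the type check reshaped in make_quant_module above: module.__class__ != nn.Linear is true for any subclass or wrapper, and isinstance() is then used to let known-safe subclasses such as compressed_tensors' CompressedLinear pass while unknown wrappers only trigger a warning. A minimal sketch of that distinction, not taken from this repository (MyWrappedLinear is a made-up subclass used only for illustration):

    from torch import nn

    class MyWrappedLinear(nn.Linear):
        # Stand-in for a Linear wrapper such as
        # compressed_tensors.linear.compressed_linear.CompressedLinear.
        pass

    plain = nn.Linear(4, 4)
    wrapped = MyWrappedLinear(4, 4)

    print(plain.__class__ != nn.Linear)    # False -> treated as a plain Linear
    print(wrapped.__class__ != nn.Linear)  # True  -> subclass/wrapper detected
    print(isinstance(wrapped, nn.Linear))  # True  -> still usable as a Linear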

fms_mo/quant/quantizers.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def get_weight_quantizer(
     recompute=False,
     perGp=None,
     use_subnormal=False,
-    emulate = True,
+    emulate=True,
 ):
     """Return a quantizer for weight quantization
     Regular quantizers:
