
Commit 0b5d68a

feat: enable fast loading and vllm format saving functionality in fms_mo
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent e9874ef · commit 0b5d68a

File tree: 6 files changed (+172, -98 lines)

fms_mo/dq.py

Lines changed: 41 additions & 27 deletions
@@ -1,11 +1,11 @@
 # Copyright The FMS Model Optimizer Authors
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -34,6 +34,7 @@
 )
 import torch
 
+import os
 # Local
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.custom_ext_kernels.utils import (
@@ -50,8 +51,11 @@
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
 from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
 from fms_mo.utils.utils import patch_torch_bmm, prepare_input
-from fms_mo.utils.dq_inf import load_fp8_vllm, save_vllm_fp8
-from accelerate import load_checkpoint_and_dispatch
+from fms_mo.utils.dq_inf import (
+    save_vllm_fp8,
+    convert_fp8_vllm_to_fms_mo,
+    check_quantization_setting,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -129,18 +133,42 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )
 
+    inference= model.config.to_dict().get("quantization_config",None)
+
+    if inference:
+        quant_setting = check_quantization_setting(inference)
+        if quant_setting:
+            logger.info("Quantization config settings validated ")
+            model = convert_fp8_vllm_to_fms_mo(model = model)
+        else:
+            exit("__This quantization config is wrong/not supported__")
+
+
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
         model.resize_token_embeddings(len(tokenizer))
 
     logger.info(f"Initialized model is: \n {model}")
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
-
-    if not fms_mo_args.inference or fms_mo_args.vllm_fp8_load:
+
+    if not inference:
+        logger.info("quantization mode activated, initalizing the qcfg file ")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
-        qcfg = qconfig_init(recipe=opt_args.output_dir+"/qcfg")
+        logger.info("inference mode activated")
+        if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"):
+            if fms_mo_args.override_fms_args:
+                logger.info("qcfg file found and some parameters are being over-written ")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args)
+            else:
+                logger.info("qcfg file found, loading the qcfg file ")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg")
+        else:
+            logger.info("qcfg file not found in {model_args.model_name_or_path},\
+                loading fms_mo_args and recipe"
+            )
+            qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
 
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -184,6 +212,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model"] = model_args.model_name_or_path
     qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
     qcfg["plotsvg"] = False
+    qcfg["output_folder"] = opt_args.output_dir
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
     calibration_dataset = calibration_dataset.with_format("torch")
@@ -196,7 +225,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not fms_mo_args.inference and qcfg["smoothq"] :
+    if not inference and qcfg["smoothq"] :
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -230,14 +259,12 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         use_layer_name_pattern_matching=use_layer_name_pattern_matching,
         use_dynamo=use_dynamo,
         dev=dev,
-        mode=fms_mo_args.inference,
+        mode=inference,
         save_fname="dq",
-        folder=opt_args.output_dir,
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
-
-    if not fms_mo_args.inference:
+    if not inference:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
@@ -264,7 +291,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
         )
         save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
-    elif opt_args.save_ckpt_for_vllm:
+    elif not opt_args.save_ckpt:
         logger.info(
             f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
         )
@@ -287,19 +314,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
             # layer_to_exclude=["lm_head",]
         )
-    else:
-        if fms_mo_args.vllm_fp8_load:
-            logger.info("loading llmcompressor fp8 model saved_checkpoint")
-            model = load_fp8_vllm( model=model, checkpoint=opt_args.output_dir)
-
-        else:
-            logger.info("loading dq fms_mo fp8 model saved_checkpoint")
-            model = load_checkpoint_and_dispatch(
-                model,
-                checkpoint=opt_args.output_dir,
-                device_map=None,
-                no_split_module_classes=['Block']
-            )
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
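For context, a minimal sketch (not part of this commit) of what the new detection path means for callers: run_dq now chooses between quantization and inference mode from the presence of a quantization_config entry in the checkpoint's config, instead of the removed inference / vllm_fp8_load flags. The checkpoint path below is a hypothetical placeholder; only AutoConfig from transformers is assumed.

# Sketch only: probe a checkpoint the same way run_dq now does.
from transformers import AutoConfig

ckpt = "path/to/fp8_checkpoint"  # hypothetical local or hub path
cfg = AutoConfig.from_pretrained(ckpt).to_dict()

# A vLLM / llm-compressor style FP8 checkpoint carries a "quantization_config"
# block in config.json; its presence is what switches run_dq into inference
# mode, where check_quantization_setting validates the config and
# convert_fp8_vllm_to_fms_mo rewrites the model into fms_mo form.
if cfg.get("quantization_config", None):
    print("pre-quantized checkpoint: run_dq validates and converts it")
else:
    print("plain checkpoint: run_dq quantizes it with the 'dq' recipe")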

fms_mo/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -281,6 +281,7 @@ def forward(self, x):
             )
 
             # pylint: disable=not-callable
+
             return F.linear(x, self.W_fp, self.bias)
         else:
             qinput = self.quantize_feature(x / scale).to(x.dtype)
@@ -296,7 +297,6 @@ def forward(self, x):
             )
 
             qbias = self.bias
-
             # pylint: disable=not-callable
             output = F.linear(qinput, qweight, qbias)
 
fms_mo/prep.py

Lines changed: 13 additions & 12 deletions
@@ -23,7 +23,7 @@
 # Third Party
 from torch import nn
 import torch
-
+import compressed_tensors
 # Local
 from fms_mo.calib import qmodel_calib
 from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
@@ -391,12 +391,14 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     # For nn.Linear
     elif isinstance(module, nn.Linear):
         if module.__class__ != nn.Linear:
-            logger.warning(
-                f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
-                "Please make sure it doesn't wrap BN and activ func."
-                "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
-            )
-
+            if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear):
+                pass
+            else:
+                logger.warning(
+                    f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
+                    "Please make sure it doesn't wrap BN and activ func."
+                    "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
+                )
         QLin = mapping.get(nn.Linear, None)
         if QLin is None:
             if verbose:
@@ -571,8 +573,8 @@ def has_quantized_module(model):
     return any(isinstance(m, quantized_modules) for m in model.modules())
 
 def swap_qbmm(model: nn.Module, qcfg: dict):
-    """Go through all model.named_modules(), try to create an equivalent Qbmm layer to replace each of
-    the existing linear Bmm layers.
+    """Go through all model.named_modules(), try to create an equivalent
+    Qbmm layer to replace each of the existing linear Bmm layers.
 
     Args:
         model (nn.Module): input model to be "prepared"
@@ -605,7 +607,7 @@ def swap_qbmm(model: nn.Module, qcfg: dict):
             qcfg=qcfg,
         )
         setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
-
+
 def qmodel_prep(
     model,
     dloader,
@@ -619,7 +621,6 @@ def qmodel_prep(
     use_dynamo=False,
     mode=False,
     verbose=False,
-    folder=None,
     **kwargs,
 ):
     """Prepare a given PyTorch model for quantization process through three parts:
@@ -951,7 +952,7 @@ def qmodel_prep(
             model, device_ids=DPorDDPdevices
         )
 
-    qconfig_save(qcfg, fname=folder+"/qcfg.json")
+    qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json")
     qcfg["tb_writer"] = tb_writer
 
     logger.info(f"--- Quantized model --- \n{model}\n")

fms_mo/quant/quantizers.py

Lines changed: 2 additions & 1 deletion
@@ -237,6 +237,7 @@ def get_weight_quantizer(
     recompute=False,
     perGp=None,
     use_subnormal=False,
+    emulate = True,
 ):
     """Return a quantizer for weight quantization
     Regular quantizers:
@@ -346,7 +347,7 @@ def get_weight_quantizer(
         weight_quantizer = to_fp8(
             nbits,
             q_mode=qw_mode,
-            emulate=True,
+            emulate=emulate,
             perCh=Nch,
         )
     else:

fms_mo/training_args.py

Lines changed: 1 addition & 6 deletions
@@ -160,10 +160,6 @@ class OptArguments(TypeChecker):
         default=False,
         metadata={"help": "Prepare and save AIU-compliant checkpoint."},
     )
-    save_ckpt_for_vllm: bool = field(
-        default=False,
-        metadata={"help": "Prepare and save vllm-compliant checkpoint."},
-    )
 
 
 @dataclass
@@ -214,8 +210,7 @@ class FMSMOArguments(TypeChecker):
         metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
     )
     fp8_use_subnormal: bool = field(default=False)
-    inference: bool = field(default=False)
-    vllm_fp8_load: bool = field(default=False)
+    override_fms_args: bool = field(default=False)
 
 
 @dataclass
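Taken together with the dq.py changes, the argument surface shrinks: mode selection no longer relies on inference or vllm_fp8_load, vLLM-format saving is gated by the existing save_ckpt option rather than a separate save_ckpt_for_vllm flag, and the only new knob is override_fms_args. A minimal sketch of the new field in isolation (the stand-alone dataclass below is illustrative, not the real FMSMOArguments):

# Sketch only: the new flag mirrored in a stand-alone dataclass. In run_dq,
# when a saved qcfg.json sits next to the checkpoint, override_fms_args=True
# reloads it while letting the passed fms_mo_args overwrite individual entries.
from dataclasses import dataclass, field

@dataclass
class FMSMOArgumentsSketch:
    override_fms_args: bool = field(default=False)

args = FMSMOArgumentsSketch(override_fms_args=True)
print(args.override_fms_args)  # prints: True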
