
Commit b137e0c

feat: fast model inference

Signed-off-by: Omobayode Fagbohungbe <[email protected]>

Parent: f9ca98a

6 files changed: +256 -43 lines

fms_mo/dq.py
Lines changed: 70 additions & 41 deletions

@@ -50,6 +50,8 @@
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
 from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
 from fms_mo.utils.utils import patch_torch_bmm, prepare_input
+from fms_mo.utils.dq_inf import load_fp8_vllm, save_vllm_fp8
+from accelerate import load_checkpoint_and_dispatch

 logger = logging.getLogger(__name__)

@@ -134,7 +136,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Initialized model is: \n {model}")
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
-    qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+
+    if not fms_mo_args.inference or fms_mo_args.vllm_fp8_load:
+        qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+    else:
+        qcfg = qconfig_init(recipe=opt_args.output_dir + "/qcfg")

     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -190,7 +196,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )

     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if qcfg["smoothq"]:
+    if not fms_mo_args.inference and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -224,53 +230,76 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         use_layer_name_pattern_matching=use_layer_name_pattern_matching,
         use_dynamo=use_dynamo,
         dev=dev,
+        mode=fms_mo_args.inference,
         save_fname="dq",
+        folder=opt_args.output_dir,
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)

-    if qcfg["smoothq"]:
-        logger.info("Starting to apply smooth scale")
-        dq_llm(model, act_scales, qcfg)
-        logger.info("Finished applying smooth scale")
+    if not fms_mo_args.inference:
+        if qcfg["smoothq"]:
+            logger.info("Starting to apply smooth scale")
+            dq_llm(model, act_scales, qcfg)
+            logger.info("Finished applying smooth scale")
+
+        if qcfg["qmodel_calibration_new"] > 0:
+            logger.info("Starting to calibrate activation clip_val")
+            if qcfg["large_model"]:
+                calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
+            else:
+                model.to("cuda")
+                pbar = tqdm(
+                    dq_dataloader,
+                    desc=" calibration after applying smoothq scale and before inference",
+                    total=qcfg["qmodel_calibration_new"],
+                )
+                for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
+                    data_mb = prepare_input(model.device, data_mb)
+                    with patch_torch_bmm(qcfg):
+                        model(**data_mb)
+
+        if opt_args.save_ckpt_for_aiu:
+            logger.info(
+                f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
+            )
+            save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
+        elif opt_args.save_ckpt_for_vllm:
+            logger.info(
+                f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
+            )
+            save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
+        elif opt_args.save_ckpt:
+            logger.info(
+                f"Saving quantized model and tokenizer to {opt_args.output_dir}"
+            )
+            model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+            tokenizer.save_pretrained(opt_args.output_dir)
+
+        if fms_mo_args.aiu_sim_triton:
+            # NOTE plz apply correct HW settings here, defaults are not real HW params
+            lower_qmodel_triton(
+                model,
+                use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
+                max_acc_bits=qcfg.get("max_acc_bits", 32),
+                num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
+                chunk_size=qcfg.get("chunk_size", 32),  # 1024
+                clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
+                # layer_to_exclude=["lm_head",]
+            )
+    else:
+        if fms_mo_args.vllm_fp8_load:
+            logger.info("loading llmcompressor fp8 model saved_checkpoint")
+            model = load_fp8_vllm(model=model, checkpoint=opt_args.output_dir)

-    if qcfg["qmodel_calibration_new"] > 0:
-        logger.info("Starting to calibrate activation clip_val")
-        if qcfg["large_model"]:
-            calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
         else:
-            model.to("cuda")
-            pbar = tqdm(
-                dq_dataloader,
-                desc=" calibration after applying smoothq scale and before inference",
-                total=qcfg["qmodel_calibration_new"],
+            logger.info("loading dq fms_mo fp8 model saved_checkpoint")
+            model = load_checkpoint_and_dispatch(
+                model,
+                checkpoint=opt_args.output_dir,
+                device_map=None,
+                no_split_module_classes=["Block"],
             )
-            for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
-                data_mb = prepare_input(model.device, data_mb)
-                with patch_torch_bmm(qcfg):
-                    model(**data_mb)
-
-    if opt_args.save_ckpt_for_aiu:
-        logger.info(
-            f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
-        )
-        save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
-    elif opt_args.save_ckpt:
-        logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
-        model.save_pretrained(opt_args.output_dir, use_safetensors=True)
-        tokenizer.save_pretrained(opt_args.output_dir)
-
-    if fms_mo_args.aiu_sim_triton:
-        # NOTE plz apply correct HW settings here, defaults are not real HW params
-        lower_qmodel_triton(
-            model,
-            use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
-            max_acc_bits=qcfg.get("max_acc_bits", 32),
-            num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
-            chunk_size=qcfg.get("chunk_size", 32),  # 1024
-            clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
-            # layer_to_exclude=["lm_head",]
-        )

     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
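
The run_dq() changes above split the flow into a quantize-and-save pass and an inference pass. A minimal two-phase usage sketch follows; the SimpleNamespace objects are illustrative only (in the repo these argument objects come from the DQ run script's parser) and show just the attributes the new branches read.

from types import SimpleNamespace

# Phase 1: quantize + calibrate, then write a vLLM-compliant FP8 checkpoint.
opt_args = SimpleNamespace(output_dir="out_fp8", save_ckpt_for_aiu=False,
                           save_ckpt_for_vllm=True, save_ckpt=False)
fms_mo_args = SimpleNamespace(inference=False, vllm_fp8_load=False,
                              aiu_sim_triton=False, eval_ppl=False)
# run_dq(model_args, data_args, opt_args, fms_mo_args)

# Phase 2: fast inference. Smoothquant and calibration are skipped; with
# vllm_fp8_load=True the llm-compressor-style FP8 checkpoint is reloaded via
# load_fp8_vllm(), otherwise the saved qcfg is re-read from output_dir and
# accelerate's load_checkpoint_and_dispatch() restores the fms_mo checkpoint.
fms_mo_args.inference = True
fms_mo_args.vllm_fp8_load = True
# run_dq(model_args, data_args, opt_args, fms_mo_args)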

fms_mo/prep.py
Lines changed: 45 additions & 1 deletion

@@ -570,7 +570,42 @@ def has_quantized_module(model):
     """Check if model is already quantized - do not want to quantize twice if so"""
     return any(isinstance(m, quantized_modules) for m in model.modules())

+def swap_qbmm(model: nn.Module, qcfg: dict):
+    """Go through the modules recorded in qcfg["bmm_prep"] and attach an equivalent
+    QBmm layer for each bmm/matmul call observed there.

+    Args:
+        model (nn.Module): input model to be "prepared"
+        qcfg (dict): quant config
+
+    Returns: model updated in place with the QBmm layers attached
+
+    """
+
+    from fms_mo.modules import QBmm
+
+    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
+        "which2patch_contextmanager"
+    ]
+    isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
+    for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
+        mod_bmm_happened = model.get_submodule(mod_name)
+        for whichQBmm, ln in enumerate(line_nums, start=1):
+            nbits = qcfg[f"nbits_bmm{whichQBmm}"]
+            newQBmm = QBmm(
+                num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
+                num_bits_m2=nbits,
+                qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
+                qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
+                m1_unidirectional=(whichQBmm == 2),
+                m1_bounded=(whichQBmm == 2),  # see Note 5
+                m2_unidirectional=False,
+                m2_bounded=False,
+                replaceBmm=isbmm,
+                qcfg=qcfg,
+            )
+            setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
+
 def qmodel_prep(
     model,
     dloader,
@@ -582,7 +617,9 @@ def qmodel_prep(
     Qcali=False,
     dev=None,
     use_dynamo=False,
+    mode=False,
     verbose=False,
+    folder=None,
     **kwargs,
 ):
     """Prepare a given PyTorch model for quantization process through three parts:
@@ -657,7 +694,14 @@ def qmodel_prep(
     Returns:
         nn.Module: quantized model ready for further PTQ/QAT
     """
+    if mode:
+
+        if qcfg.get("QBmm"):
+            swap_qbmm(model, qcfg)

+        model = q_any_net_5(model, qcfg, verbose=False)
+        return model
+
     sys.setrecursionlimit(4000)

     currDev = next(model.parameters()).device if dev is None else dev
@@ -907,7 +951,7 @@ def qmodel_prep(
         model, device_ids=DPorDDPdevices
     )

-    qconfig_save(qcfg, fname="qcfg.json")
+    qconfig_save(qcfg, fname=folder + "/qcfg.json")
     qcfg["tb_writer"] = tb_writer

     logger.info(f"--- Quantized model --- \n{model}\n")
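
For reference, an illustrative shape of the qcfg entries that swap_qbmm() reads; all values below are placeholders (including the mode strings), since in practice these entries are produced by fms_mo's tracing/prep step rather than written by hand.

qcfg_example = {
    "bmm_prep": {
        "which2patch_contextmanager": "torch.bmm",
        # module name -> source line numbers where a bmm/matmul was observed
        "layers_with_bmm": {"model.layers.0.self_attn": [337, 352]},
    },
    "nbits_bmm1": 8,
    "nbits_bmm2": 8,
    "bmm1_qm1_mode": "maxsym",  # placeholder mode strings
    "bmm1_qm2_mode": "maxsym",
    "bmm2_qm1_mode": "maxsym",
    "bmm2_qm2_mode": "maxsym",
}
# swap_qbmm(model, qcfg_example) would then attach QBmm337 and QBmm352 to
# model.layers.0.self_attn (the first treated as bmm1, the second as bmm2).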

fms_mo/recipes/quant.json
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+{
+    "quantization_config": {
+        "config_groups": {
+            "group_0": {
+                "input_activations": {
+                    "actorder": null,
+                    "block_structure": null,
+                    "dynamic": true,
+                    "group_size": null,
+                    "num_bits": 8,
+                    "observer": null,
+                    "observer_kwargs": {},
+                    "strategy": "token",
+                    "symmetric": true,
+                    "type": "float"
+                },
+                "output_activations": null,
+                "targets": [
+                    "Linear"
+                ],
+                "weights": {
+                    "actorder": null,
+                    "block_structure": null,
+                    "dynamic": false,
+                    "group_size": null,
+                    "num_bits": 8,
+                    "observer": "minmax",
+                    "observer_kwargs": {},
+                    "strategy": "channel",
+                    "symmetric": true,
+                    "type": "float"
+                }
+            }
+        },
+        "format": "float-quantized",
+        "global_compression_ratio": null,
+        "ignore": [
+            "lm_head"
+        ],
+        "kv_cache_scheme": null,
+        "quant_method": "compressed-tensors",
+        "quantization_status": "compressed"
+    }
+}
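
This recipe is the compressed-tensors (llm-compressor style) quantization_config that save_vllm_fp8() merges into the exported model's config.json (see fms_mo/utils/dq_inf.py below). A small sketch of how it is read, assuming get_recipe() simply returns the parsed JSON as a dict:

import json
from fms_mo.utils.qconfig_utils import get_recipe

data = get_recipe("quant")  # loads fms_mo/recipes/quant.json
weights_scheme = data["quantization_config"]["config_groups"]["group_0"]["weights"]
print(json.dumps(weights_scheme, indent=2))  # FP8 weights: per-channel, static, symmetric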

fms_mo/training_args.py
Lines changed: 7 additions & 0 deletions

@@ -160,6 +160,10 @@ class OptArguments(TypeChecker):
         default=False,
         metadata={"help": "Prepare and save AIU-compliant checkpoint."},
     )
+    save_ckpt_for_vllm: bool = field(
+        default=False,
+        metadata={"help": "Prepare and save vLLM-compliant checkpoint."},
+    )


 @dataclass
@@ -209,6 +213,9 @@ class FMSMOArguments(TypeChecker):
         default=False,
         metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
     )
+    fp8_use_subnormal: bool = field(default=False)
+    inference: bool = field(default=False)
+    vllm_fp8_load: bool = field(default=False)


 @dataclass
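
A minimal sketch of how these dataclass fields surface as command-line flags, assuming the dataclasses are parsed with transformers.HfArgumentParser as elsewhere in fms_mo; the standalone dataclass and the flag values below are illustrative, not part of this commit.

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class InferenceFlagsSketch:
    inference: bool = field(default=False)
    vllm_fp8_load: bool = field(default=False)
    save_ckpt_for_vllm: bool = field(default=False)

parser = HfArgumentParser(InferenceFlagsSketch)
# e.g. python run_dq.py --inference True --vllm_fp8_load True
(args,) = parser.parse_args_into_dataclasses(
    ["--inference", "True", "--vllm_fp8_load", "True"]
)
print(args.inference, args.vllm_fp8_load, args.save_ckpt_for_vllm)  # True True False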

fms_mo/utils/dq_inf.py
Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+import torch
+import fms_mo
+from fms_mo.quant.quantizers import to_fp8_scaled_perCh as fp8
+from huggingface_hub import save_torch_state_dict
+import json
+import os
+import glob
+from fms_mo.utils.qconfig_utils import get_recipe
+from safetensors.torch import load_file, save_file
+from torch import nn
+
+
+def save_vllm_fp8(model: nn.Module, qcfg: dict, tokenizer=None, folder: str = None):
+    """
+    Function to save an fms_mo fp8 checkpoint in vLLM fp8 format.
+    """
+
+    st_dict = {}
+
+    for k, v in model.state_dict().items():
+        if k[-11:] == "proj.weight":
+            weight, scale = fp8(v, emulate=False)
+            st_dict[k] = weight
+
+            if k[:-7] in qcfg["qskip_layer_name"]:
+                pass
+            else:
+                st_dict[k + "_scale"] = 1 / scale
+
+        elif k[-6:] == "weight":
+            st_dict[k] = v
+        else:
+            pass
+
+    config = model.config.to_dict()
+
+    # TODO: to support multiple recipes, check qconfig arguments and update data loaded from quant.json
+    data = get_recipe("quant")
+
+    config.update(data)
+
+    save_torch_state_dict(st_dict, folder)
+
+    tokenizer.save_pretrained(folder)
+
+    with open(folder + "/config.json", "a") as f:
+        json.dump(config, f, indent=4)
+
+
+def find_file_glob(pattern: str, search_path: str):
+    """
+    Finds files matching a pattern within a directory and its subdirectories.
+    """
+    # Use '**' for recursive search in modern Python versions (3.5+)
+    full_pattern = os.path.join(search_path, "**", pattern)
+    found_files = glob.glob(full_pattern, recursive=True)
+    return sorted(found_files)
+
+
+def load_fp8_vllm(model: nn.Module = None, checkpoint: str = None):
+    """
+    Function to help load a vLLM fp8 checkpoint into fms_mo.
+    """
+
+    merged_files_dict = {}
+
+    files = find_file_glob("*.safetensors", checkpoint)
+
+    model_dict = model.state_dict()
+
+    for file in files:
+        merged_files_dict = load_file(file)
+
+        for k, v in merged_files_dict.items():
+
+            if k[-11:] == "proj.weight":
+                scale = merged_files_dict[k + "_scale"].reshape(-1, 1)
+                model_dict[k] = merged_files_dict[k].to(torch.float16) * scale
+
+            elif k[-6:] == "weight":
+                model_dict[k] = v
+
+            else:
+                pass
+
+    return model
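
A usage sketch for the two helpers above; the model id and output folder are placeholders, and qcfg is reduced to the single key these helpers read (real runs populate it via qconfig_init).

from transformers import AutoModelForCausalLM, AutoTokenizer
from fms_mo.utils.dq_inf import save_vllm_fp8, load_fp8_vllm

model_id = "ibm-granite/granite-3.0-2b-instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
qcfg = {"qskip_layer_name": []}  # layers whose *_scale should not be exported

# Export: per-channel FP8 weights + scales for every *proj.weight, plus the
# compressed-tensors quantization_config from recipes/quant.json in config.json.
save_vllm_fp8(model, qcfg, tokenizer, "out_fp8")

# Re-import: multiply each FP8 weight by its stored scale back into an fp16 state dict.
model = load_fp8_vllm(model=model, checkpoint="out_fp8")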

fms_mo/utils/dq_utils.py
Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def config_quantize_smooth_layers(qcfg: dict):
         for llama_family, layers in large_mag_layers.items():
             if llama_family in qcfg["model"]:
                 qcfg["qskip_layer_name"] += [
-                    f"model.layers.{i}.mlp.down_proj" for i in layers
+                    f"model.layers.{i}.mlp.down_projj" for i in layers
                 ]
                 break
     elif any(model in qcfg["model"] for model in granite_architecture) or any(
