Commit 929477d

feat: modified fms_mo for inference

Signed-off-by: omobayode.fagbohungbe <[email protected]>
1 parent 9e2dc3e commit 929477d
File tree

5 files changed: +102 / -47 lines changed

  fms_mo/__init__.py
  fms_mo/dq.py
  fms_mo/prep.py
  fms_mo/training_args.py
  fms_mo/utils/qconfig_utils.py

fms_mo/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@

 # Local
 from fms_mo.prep import qmodel_prep
-from fms_mo.utils.qconfig_utils import qconfig_init
+from fms_mo.utils.qconfig_utils import qconfig_init, qconfig_load

 VERSION_FALLBACK = "0.0.0"

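The package root now re-exports qconfig_load next to qconfig_init, so callers can read a saved quant config without reaching into fms_mo.utils.qconfig_utils; dq.py adds the same import below, although its inference branch still loads the JSON by hand. A one-line usage sketch (file name illustrative, not from the commit):

from fms_mo import qconfig_load

saved_cfg = qconfig_load("qcfg2.json")  # returns the saved config as a dict
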
fms_mo/dq.py

Lines changed: 64 additions & 44 deletions
@@ -35,7 +35,7 @@
 import torch

 # Local
-from fms_mo import qconfig_init, qmodel_prep
+from fms_mo import qconfig_init, qmodel_prep, qconfig_load
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
     calibration_llm_1GPU,
@@ -145,7 +145,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     ]
     qcfg["large_model"] = any(
         name in model_args.model_name_or_path for name in known_large_models
-    ) or (gpu_mem_util_per > 0.7)
+    ) or (gpu_mem_util_per > 0.1)
     dev = "cpu" if qcfg["large_model"] else "cuda"
     if model_args.device_map is None:
         model.to(dev)
@@ -154,6 +154,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model_type"] = model.config.model_type

     qcfg["model"] = model_args.model_name_or_path
+    qcfg["qskip_large_mag_layers"] = True
     # config layers to skip, smooth scale
     config_quantize_smooth_layers(qcfg)

@@ -174,6 +175,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["model"] = model_args.model_name_or_path
     qcfg["smoothq"] = True
     qcfg["plotsvg"] = False
+

     calibration_dataset = load_from_disk(data_args.training_data_path)
     calibration_dataset = calibration_dataset.with_format("torch")
@@ -184,62 +186,80 @@
         collate_fn=default_data_collator,
         batch_size=1,
     )
-
+    #print(fms_mo_args)
+    #ii
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
-    if qcfg.get("act_scale_path", None):
-        # user provided a scale file (or a dir)
-        scale_file_or_dir = Path(qcfg["act_scale_path"])
-        if scale_file_or_dir.is_dir():
-            scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
-        elif scale_file_or_dir.is_file():
-            scale_file = scale_file_or_dir
+    if not fms_mo_args.inference:
+        scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
+        if qcfg.get("act_scale_path", None):
+            # user provided a scale file (or a dir)
+            scale_file_or_dir = Path(qcfg["act_scale_path"])
+            if scale_file_or_dir.is_dir():
+                scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
+            elif scale_file_or_dir.is_file():
+                scale_file = scale_file_or_dir

-    if not scale_file.parent.exists():
-        scale_file.parent.mkdir(exist_ok=False)
+        if not scale_file.parent.exists():
+            scale_file.parent.mkdir(exist_ok=False)

-    if scale_file.exists():
-        act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
-    else:
-        logger.info("Generate activation scales")
-        if qcfg["large_model"]:
-            act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
+        if scale_file.exists():
+            act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
         else:
-            act_scales = get_act_scales(model, dq_dataloader, qcfg)
-        torch.save(act_scales, scale_file)
+            logger.info("Generate activation scales")
+            if qcfg["large_model"]:
+                act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
+            else:
+                act_scales = get_act_scales(model, dq_dataloader, qcfg)
+            torch.save(act_scales, scale_file)
+    else:
+        import json
+        q_file = open('qcfg_llama.json', "r", encoding="utf-8")
+        a = json.load(q_file)
+        print(a)
+        qcfg.update(a)
+        print(qcfg)
+
     qmodel_prep(
         model,
         dq_dataloader,
         qcfg,
         use_layer_name_pattern_matching=use_layer_name_pattern_matching,
         use_dynamo=use_dynamo,
         dev=dev,
+        mode=fms_mo_args.inference,
         save_fname="dq",
     )
     logger.info(f"Quantized model {model}")
-    logger.info("Starting to apply smooth scale")
-    dq_llm(model, act_scales, qcfg)
-    logger.info("Finished applying smooth scale")
-    logger.info("==" * 20)
-    if qcfg["qmodel_calibration_new"] > 0:
-        logger.info("Starting to calibrate activation clip_val")
-        if qcfg["large_model"]:
-            calibration_llm_1GPU(qcfg, model, dq_dataloader)
-        else:
-            model.to("cuda:0")
-            pbar = tqdm(
-                dq_dataloader,
-                desc=" calibration after applying smoothq scale and before inference",
-                total=qcfg["qmodel_calibration_new"],
-            )
-            for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
-                data_mb = prepare_input(model.device, data_mb)
-                with patch_torch_bmm(qcfg):
-                    model(**data_mb)

-    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
-    model.save_pretrained(opt_args.output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(opt_args.output_dir)
+    if not fms_mo_args.inference:
+        logger.info("Starting to apply smooth scale")
+        dq_llm(model, act_scales, qcfg)
+        logger.info("Finished applying smooth scale")
+        logger.info("==" * 20)
+        if qcfg["qmodel_calibration_new"] > 0:
+            logger.info("Starting to calibrate activation clip_val")
+            if qcfg["large_model"]:
+                calibration_llm_1GPU(qcfg, model, dq_dataloader)
+            else:
+                model.to("cuda:0")
+                pbar = tqdm(
+                    dq_dataloader,
+                    desc=" calibration after applying smoothq scale and before inference",
+                    total=qcfg["qmodel_calibration_new"],
+                )
+                for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
+                    data_mb = prepare_input(model.device, data_mb)
+                    with patch_torch_bmm(qcfg):
+                        model(**data_mb)
+        logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+        model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+        tokenizer.save_pretrained(opt_args.output_dir)
+    else:
+        pass
+        from accelerate import load_checkpoint_and_dispatch
+        model = load_checkpoint_and_dispatch(model, checkpoint=opt_args.output_dir, device_map=None, no_split_module_classes=['Block'])
+
+

     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
@@ -253,7 +273,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):

         logger.info(f"Model for evaluation: {model}")
         if qcfg["large_model"]:
-            eval_llm_1GPU(qcfg, model, test_dataset)
+            eval_llm_1GPU(qcfg, model.to('cpu'), test_dataset)
         else:
             model.to(torch.device("cuda:0"))
             n_samples = int(test_dataset.input_ids.shape[1] / block_size)

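The inference branch above assumes an earlier quantization run already produced its inputs: the checkpoint directory written by model.save_pretrained() and a quant config JSON (the branch reads a hard-coded qcfg_llama.json, while qmodel_prep now writes qcfg2.json). Below is a minimal sketch, not part of the commit, of that reload step using the newly exported qconfig_load() in place of the inline json.load(); the helper name reload_quantized and the reuse of run_dq's locals (model, dq_dataloader, output_dir) are illustrative assumptions.

# Sketch only, not from the commit: reload a quantized model for inference.
from accelerate import load_checkpoint_and_dispatch

from fms_mo import qconfig_load, qmodel_prep


def reload_quantized(model, dq_dataloader, qcfg, output_dir):
    # Merge the saved quantization settings into the live config dict
    # (the commit does this with a hand-rolled json.load on 'qcfg_llama.json').
    qcfg.update(qconfig_load("qcfg_llama.json") or {})
    # mode=True makes qmodel_prep rebuild the quantized modules directly
    # instead of tracing and calibrating again.
    qmodel_prep(model, dq_dataloader, qcfg, mode=True)
    # Pull the quantized weights saved by the non-inference path back in.
    return load_checkpoint_and_dispatch(
        model,
        checkpoint=output_dir,
        device_map=None,
        no_split_module_classes=["Block"],
    )
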
fms_mo/prep.py

Lines changed: 34 additions & 1 deletion
@@ -535,6 +535,30 @@ def has_quantized_module(model):
     """Check if model is already quantized - do not want to quantize twice if so"""
     return any(isinstance(m, quantized_modules) for m in model.modules())

+def swap_qbmm(model, qcfg):
+    from fms_mo.modules import QBmm
+
+    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
+        "which2patch_contextmanager"
+    ]
+    isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
+    for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
+        mod_bmm_happened = model.get_submodule(mod_name)
+        for whichQBmm, ln in enumerate(line_nums, start=1):
+            nbits = qcfg[f"nbits_bmm{whichQBmm}"]
+            newQBmm = QBmm(
+                num_bits_m1=max(nbits, 8) if whichQBmm == 2 else nbits,
+                num_bits_m2=nbits,
+                qm1_mode=qcfg[f"bmm{whichQBmm}_qm1_mode"],
+                qm2_mode=qcfg[f"bmm{whichQBmm}_qm2_mode"],
+                m1_unidirectional=(whichQBmm == 2),
+                m1_bounded=(whichQBmm == 2),  # see Note 5
+                m2_unidirectional=False,
+                m2_bounded=False,
+                replaceBmm=isbmm,
+                qcfg=qcfg,
+            )
+            setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)

 def qmodel_prep(
     model,
@@ -548,6 +572,7 @@ def qmodel_prep(
     dev=None,
     use_dynamo=False,
     verbose=False,
+    mode=False,
     **kwargs,
 ):
     """Prepare a given PyTorch model for quantization process through three parts:
@@ -622,7 +647,15 @@ def qmodel_prep(
     Returns:
         nn.Module: quantized model ready for further PTQ/QAT
     """
+    if mode:
+
+        if qcfg.get("QBmm"):
+            pass
+            swap_qbmm(model, qcfg)

+        model = q_any_net_5(model, qcfg, verbose=False)
+        return model
+
     sys.setrecursionlimit(4000)

     currDev = next(model.parameters()).device if dev is None else dev
@@ -869,7 +902,7 @@
             model, device_ids=DPorDDPdevices
         )

-    qconfig_save(qcfg, fname="qcfg.json")
+    qconfig_save(qcfg, fname="qcfg2.json")
     qcfg["tb_writer"] = tb_writer

     logger.info(f"--- Quantized model --- \n{model}\n")

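swap_qbmm() re-attaches QBmm modules from a previously recorded map instead of re-discovering them: for every module listed in qcfg["bmm_prep"]["layers_with_bmm"] it builds one QBmm per recorded line number and sets it on that submodule as QBmm<line>. The snippet below is illustrative only and shows the config shape that loop implies; the module name and line numbers are invented.

# Illustrative only (not from the commit): the structure swap_qbmm() expects in qcfg.
example_bmm_prep = {
    # compared against "torch.bmm" to decide the replaceBmm flag passed to QBmm
    "which2patch_contextmanager": "torch.bmm",
    "layers_with_bmm": {
        # submodule name -> source line numbers where a bmm/matmul was found;
        # the first entry gets the nbits_bmm1/bmm1_* settings, the second the bmm2_* ones,
        # and swap_qbmm attaches QBmm101 and QBmm118 to this submodule (names invented)
        "model.layers.0.self_attn": [101, 118],
    },
}
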
fms_mo/training_args.py

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,8 @@ class FMSMOArguments(TypeChecker):
         default=2048, metadata={"help": "input sequence length after tokenization"}
     )
     eval_ppl: bool = field(default=False)
+    inference: bool = field(default=False)
+


 @dataclass
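
The new inference flag defaults to False, so existing quantization runs are unchanged; setting it switches run_dq to the reload path shown in fms_mo/dq.py above. A small sketch of enabling it programmatically, assuming the remaining FMSMOArguments fields keep their defaults and the other argument objects are built as for a normal run:

# Sketch only: drive run_dq in the new inference mode.
from fms_mo.dq import run_dq
from fms_mo.training_args import FMSMOArguments

fms_mo_args = FMSMOArguments(inference=True, eval_ppl=True)
# model_args, data_args and opt_args are assumed to be prepared as usual.
run_dq(model_args, data_args, opt_args, fms_mo_args)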

fms_mo/utils/qconfig_utils.py

Lines changed: 1 addition & 1 deletion
@@ -623,7 +623,7 @@ def qconfig_save(
 def qconfig_load(fname: str = "qcfg.json") -> dict:
     """Read config in json format, work together with qconfig_save"""
     config = get_recipe(fname)
-
+
     if config:
         # Check that loaded file is a dict
         if not isinstance(config, dict):

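qconfig_load() is the read half of qconfig_save(): it pulls the JSON back through get_recipe() and returns a dict after the type checks below. A minimal round-trip sketch, assuming qconfig_init() works with defaults and that get_recipe() resolves the same file name qconfig_save() wrote (qcfg2.json after this commit):

# Sketch only: save a config during quantization, read it back later.
from fms_mo.utils.qconfig_utils import qconfig_init, qconfig_load, qconfig_save

qcfg = qconfig_init()                    # assumption: default construction is enough
qconfig_save(qcfg, fname="qcfg2.json")   # file name qmodel_prep now uses
restored = qconfig_load("qcfg2.json")    # dict on success, falsy otherwise
if restored:
    qcfg.update(restored)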