Commit 362d521

fix dq issue with llama3-70b on single gpu
Signed-off-by: cliu-us <[email protected]>
1 parent 8e3f16a commit 362d521

4 files changed: 148 additions & 37 deletions

fms_mo/dq.py

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
-    calibration_llm_1GPU,
+    calibration_llm_1GPU_v2,
     dq_llm,
     get_act_scales,
     get_act_scales_1gpu,
@@ -224,9 +224,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     if qcfg["qmodel_calibration_new"] > 0:
         logger.info("Starting to calibrate activation clip_val")
         if qcfg["large_model"]:
-            calibration_llm_1GPU(qcfg, model, dq_dataloader)
+            calibration_llm_1GPU_v2(qcfg, model, dq_dataloader)
         else:
-            model.to("cuda:0")
+            model.to("cuda")
             pbar = tqdm(
                 dq_dataloader,
                 desc=" calibration after applying smoothq scale and before inference",

fms_mo/prep.py

Lines changed: 16 additions & 3 deletions
@@ -177,7 +177,10 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     is mappable, create a Qmodule and return, otherwise, return the original module. In the future,
     Qmodules need to have a .from_torch() or .from_nn() classmethod, and then this function will be
     greatly simplified.
-    NOTE: This func will check qskip_layer_name before creating the Qmodule
+    NOTE:
+    1. This func will check qskip_layer_name before creating the Qmodule
+    2. Qmodule will be created on "meta device" as a placeholder, which will skip params init and
+       mem alloc, as weights and bias will be reassigned to module.weight/.bias right after

     Args:
         module (nn.Module): the module which Qmodule will be based on
@@ -216,7 +219,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     if hasattr(module, "__constants__"):
         base_params = {k: getattr(module, k) for k in module.__constants__}
         base_params["bias"] = module.bias is not None
-        base_params["device"] = next(module.parameters()).device  # usually cuda
+        base_params["device"] = "meta"

     module_output = module

@@ -499,8 +502,17 @@ def q_any_net_5(model: nn.Module, qcfg: dict, verbose: bool = False):
     """
     # Third Party
     from torch.ao.quantization.utils import _parent_name
+    from tqdm import tqdm
+
+    total_modules = len(list(model.named_modules()))
+    pbar = tqdm(
+        model.named_modules(),
+        total=total_modules,
+        desc="Mapping modules to target Qmodules.",
+    )
+    for name, module in pbar:
+        pbar.set_description(f"processing {name}")

-    for name, module in model.named_modules():
         parent_module_name, curr_mod_name = _parent_name(name)
         new_module = make_quant_module(module, name, qcfg)
         parent_module = model.get_submodule(parent_module_name)
@@ -525,6 +537,7 @@ def q_any_net_5(model: nn.Module, qcfg: dict, verbose: bool = False):
         if verbose:
             logger.info(f"Swap ({name}) from {type(module)} to {type(new_module)}")

+    pbar.close()
     return model


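The "meta device" note above can be illustrated with a small, hedged sketch; a plain nn.Linear stands in for the actual Qmodule class, and the helper name is hypothetical:

    import torch.nn as nn

    def make_placeholder_linear(module: nn.Linear) -> nn.Linear:
        # Build the replacement on the "meta" device: no param init, no memory alloc.
        new_mod = nn.Linear(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            device="meta",
        )
        # Reassign the original parameters right after, as the NOTE describes.
        new_mod.weight = module.weight
        if module.bias is not None:
            new_mod.bias = module.bias
        return new_mod

    # usage: q = make_placeholder_linear(nn.Linear(4096, 4096))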

fms_mo/quant/ptq.py

Lines changed: 110 additions & 19 deletions
@@ -419,15 +419,25 @@ class PTQHookRecInOutLMv2(nn.Module):
     leave the special handling, e.g. reshape/cat/shuffling...etc, for later
     """

-    def __init__(self, qcfg, name=None, cls2rec=(nn.Conv2d,), recInOnly=False):
+    def __init__(
+        self,
+        qcfg,
+        name=None,
+        cls2rec=(nn.Conv2d, nn.Linear),
+        recInOnly=False,
+        stop_after_rec=False,
+        cache_dev="cuda",
+    ):
         super().__init__()
         self.name = name
         self.qcfg = qcfg
         self.cls2rec = cls2rec
         self.rec_input_only = recInOnly
         self.num_valid_input = -1
+        self.stop_after_rec = stop_after_rec
+        self.cache_dev = cache_dev

-    def __call__(self, mod, inputs, output):
+    def __call__(self, mod, inputs, *args, **_kwargs):
         # make sure this module/block's ptqmode is not 'q_out'
         submods = [m for m in mod.modules() if isinstance(m, self.cls2rec)]
         if any(sm.ptqmode == "q_out" for sm in submods):
@@ -448,7 +458,7 @@ def __call__(self, mod, inputs, output):
         # check available GPU memory, cache on GPU if possible:
         GPUmem_available, _GPUmem_total = torch.cuda.mem_get_info()
         # 1 block for SQUAD/BERT 500 batches*12/batch = ~10G
-        if GPUmem_available / 1e9 > 20:
+        if self.cache_dev == "cuda" and GPUmem_available / 1e9 > 20:
             cache_device = "cuda"
         else:
             cache_device = "cpu"
@@ -461,13 +471,15 @@ def __call__(self, mod, inputs, output):
         )

         # output could be a tuple of a single tensor or simply a tensor ?
-        assert isinstance(output, (torch.Tensor, tuple))
-        if not self.rec_input_only:
+        if not self.rec_input_only and "output" in args:
+            output = args["output"]
+            assert isinstance(output, (torch.Tensor, tuple))
             self.qcfg["cached_output"].append(
                 output[0].detach().to(cache_device)
                 if isinstance(output, tuple)
                 else output.detach().to(cache_device)
             )
+        assert not self.stop_after_rec


 # this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
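A hedged, self-contained sketch of the record-then-stop pattern that the new stop_after_rec and cache_dev arguments enable (RecordInputsHook is an illustrative stand-in, not the repo's PTQHookRecInOutLMv2):

    import torch
    import torch.nn as nn

    class RecordInputsHook:
        """Cache a block's inputs, then abort the forward pass once recording is done."""

        def __init__(self, cache, stop_after_rec=True, cache_dev="cpu"):
            self.cache = cache
            self.stop_after_rec = stop_after_rec
            self.cache_dev = cache_dev

        def __call__(self, mod, inputs, *args, **_kwargs):
            self.cache.append(inputs[0].detach().to(self.cache_dev))
            assert not self.stop_after_rec  # deliberately halts the rest of the forward

    # usage: run one batch and catch the deliberate AssertionError
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    cache = []
    handle = model[0].register_forward_hook(RecordInputsHook(cache))
    try:
        model(torch.randn(2, 8))
    except AssertionError:
        pass  # inputs to the first block are now in `cache`
    handle.remove()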
@@ -2021,7 +2033,7 @@ def get_blocks(model, model_type=None):
         "llama": (
             "model.layers",
             "model.embed_tokens",
-            None,
+            "model.rotary_emb",
             None,
             "model.norm",
             "lm_head",
@@ -2111,20 +2123,16 @@ def cache_block0_inputs(
     model, dloader, qcfg, blocks, emb=None, emb_pos=None, emb_ln=None, dev="cpu"
 ):
     """
-    To cache the input to the first transformer block.
+    To cache the input to the first transformer block. Basically a "forward_pre_hook"
+    NOTE, change caching from tensor to list to allow varying input length, slightly
+    increase memeory due to mask and alibi.
     """
     emb = emb.to(dev)
     if emb_pos is not None:
         emb_pos.to(dev)
     if emb_ln is not None:
         emb_ln = emb_ln.to(dev)
     blocks[0] = blocks[0].to(dev)
-    # NOTE, change caching from tensor to list to allow varying input length, slightly
-    # increase memeory due to mask and alibi.
-    qcfg["cached_block0_input"] = []
-    qcfg["cache_id"] = 0
-    qcfg["cached_mask"] = []
-    qcfg["cached_alibi"] = []
     # move block0 to GPU and excuting fwd() until finish block0
     if "fms" in qcfg["model_type"]:
         qcfg["kw_to_cache"] = {
@@ -2142,9 +2150,16 @@ def cache_block0_inputs(
         }
     blocks[0] = RunModule(blocks[0], qcfg)

+    # clear up old cache, if exists.
+    qcfg["cached_block0_input"] = []
+    qcfg["cache_id"] = 0
+    for kw in qcfg["kw_to_cache"].values():
+        if kw in qcfg:
+            qcfg[kw] = []
+
     if isinstance(dloader, torch.utils.data.DataLoader):
         pbar = tqdm(
-            dloader, desc="Phase 0: PTQ caching block0 input", total=qcfg["ptq_nbatch"]
+            dloader, desc="Phase 0: Caching block0 inputs", total=qcfg["ptq_nbatch"]
         )
         for data_mb, _ in zip(pbar, range(qcfg["ptq_nbatch"])):
             try:
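The updated docstring calls this function "basically a forward_pre_hook"; a minimal sketch of that idea, assuming qcfg["kw_to_cache"] maps forward kwarg names to the qcfg list names used later (the helper name is hypothetical):

    import torch.nn as nn

    def attach_block0_cache(block0: nn.Module, qcfg: dict):
        qcfg["cached_block0_input"] = []

        def pre_hook(_mod, args, kwargs):
            # cache the incoming hidden state plus any kwargs we care about
            qcfg["cached_block0_input"].append(args[0].detach().cpu())
            for kw_org, kw_new in qcfg.get("kw_to_cache", {}).items():
                if kw_org in kwargs:
                    qcfg.setdefault(kw_new, []).append(kwargs[kw_org])
            return args, kwargs

        return block0.register_forward_pre_hook(pre_hook, with_kwargs=True)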
@@ -2310,9 +2325,8 @@ def freeze_layers(m, layer_list):

 @torch.no_grad()
 def calibration_llm_1GPU(qcfg, model, dloader):
-    """
-    calibration for large models that can not fit the whole model on 1 GPU.
-    """
+    """Calibration for large models that can not fit on 1 GPU."""
+
     model.train()
     dev = "cuda"
     qcfg["batch_size"] = 1
@@ -2365,6 +2379,83 @@ def calibration_llm_1GPU(qcfg, model, dloader):
     logger.info("All blocks are calibrated")


+@torch.no_grad()
+def calibration_llm_1GPU_v2(qcfg, model, dloader):
+    """
+    Improved version of Calibration for large language models that can not fit on 1 GPU with new
+    (built-in) calibration mechanism.
+    NOTE:
+    1. Calibration only, NO update to weights!
+    2. Rely on a alternative "pre fwd hook" to cache all possible inputs.
+    3. As calibration usually cache a small number of data only, no need to move each batch back and
+    forth between GPU and CPU.
+    """
+
+    model.train()
+    dev = "cuda"
+    qcfg["batch_size"] = 1
+    qcfg["dtype"] = next(iter(model.parameters())).dtype
+    qcfg["n_samples"] = min(qcfg["ptq_nbatch"], qcfg["qmodel_calibration_new"])
+
+    assert "model_type" in qcfg, "Unknown model type. please check before proceed."
+    assert isinstance(
+        dloader, torch.utils.data.DataLoader
+    ), "Please provide a valid dataloader."
+    # --- Phase 0 cache the inputs of the block0---
+    model.config.use_cache = False
+    blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"])
+
+    cache_block0_inputs(
+        model,
+        dloader,
+        qcfg,
+        blocks,
+        emb=emb,
+        emb_pos=emb_pos,
+        emb_ln=emb_ln,
+        dev="cpu",
+    )
+    logger.info("Done, caching inputs to block0 for calibration")
+
+    # --- Phase 1 --- compute blocks and last linear layer
+    pbar = tqdm(
+        blocks, desc="Phase 1: Calibration for each block", position=0, leave=True
+    )
+    qcfg["cached_input"] = [
+        inp.clone().detach().to(dev) for inp in qcfg["cached_block0_input"]
+    ]
+    kw_to_use = {
+        kw_org: kw_new
+        for kw_org, kw_new in qcfg["kw_to_cache"].items()
+        if len(qcfg[kw_new]) == len(qcfg["cached_input"])
+    }
+    for _num_block, m in enumerate(pbar):
+        m.to(dev)
+        for i in tqdm(
+            range(qcfg["n_samples"]), desc="number of samples", position=1, leave=False
+        ):
+            if qcfg["cached_alibi"]:
+                cached_inp_prev_lay = qcfg["cached_input"][i].unsqueeze(0).to(dev)
+                data_mb = {
+                    "attention_mask": qcfg["cached_mask"][i].unsqueeze(0).to(dev),
+                    "alibi": qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
+                }
+            else:
+                cached_inp_prev_lay = qcfg["cached_input"][i]
+                data_mb = {
+                    kw_org: move_to(qcfg[kw_new][i], dev)
+                    for kw_org, kw_new in kw_to_use.items()
+                }
+
+            with patch_torch_bmm(qcfg):
+                qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0]
+
+        m.cpu()
+        torch.cuda.empty_cache()
+
+    logger.info("All blocks are calibrated")
+
+
 @torch.no_grad()
 def activation_stats(name, tensor, act_scales):
     # TODO if 'QBmm' in name: reshape the tensor.
@@ -2498,8 +2589,8 @@ def get_act_scales_1gpu(model, dloader, qcfg):

     assert "model_type" in qcfg, "Unknown model type. please check before proceed."
     assert (
-        qcfg["loader_len"] == qcfg["ptq_nbatch"]
-    ), "set batch_size=1 and PTQ samples== Nbatches"
+        qcfg["loader_len"] >= qcfg["ptq_nbatch"]
+    ), "Please make sure dataloader has enough data needed for PTQ (ie. check qcfg['ptq_nbatch'])."
     # --- Phase 0 cache the inputs of the block0---
     blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"])
     cache_block0_inputs(
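The Phase 1 loop in calibration_llm_1GPU_v2 follows the usual block-streaming pattern; a compact, hedged sketch with generic names (no fms_mo specifics):

    import torch

    @torch.no_grad()
    def run_blocks_one_gpu(blocks, cached_inputs, dev="cuda"):
        # One decoder block resident on the GPU at a time; cached activations
        # stay on-device and become the next block's inputs.
        for block in blocks:
            block.to(dev)
            for i, hid in enumerate(cached_inputs):
                out = block(hid)
                cached_inputs[i] = out[0] if isinstance(out, tuple) else out
            block.cpu()
            torch.cuda.empty_cache()
        return cached_inputs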

fms_mo/utils/eval_utils.py

Lines changed: 19 additions & 12 deletions
@@ -26,7 +26,7 @@

 # Local
 from fms_mo.quant.ptq import cache_block0_inputs, get_blocks
-from fms_mo.utils.utils import patch_torch_bmm
+from fms_mo.utils.utils import move_to, patch_torch_bmm

 logger = logging.getLogger(__name__)

@@ -35,11 +35,13 @@
 def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs):  # pylint: disable=unused-argument
     """
     Evaluate causal LLM with 1GPU, return perplexity
-    Note: currently taking test_dataset as dict (instead of dataloader)
-    Used for models that cannot fit into a 1 GPU.
+    Note:
+    1. currently taking test_dataset as dict (instead of dataloader)
+    2. Used for models that cannot fit into a 1 GPU. Will need to move modules back and forth.
+    3. Keep hid_state on device to reduce uncessary data transfer.
     """
     model.eval()
-    dev = "cuda:0"  # cuda:0 is used for PTQ
+    dev = "cuda"
     qcfg["batch_size"] = 1  # for dataloading, always use batch_size of 1
     qcfg["dtype"] = next(iter(model.parameters())).dtype
     seq_len = qcfg["seq_len"]
@@ -63,7 +65,14 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
     # Phase 1: compute blocks and last linear layer
     pbar = tqdm(blocks, desc="evaluation: compute blocks")

-    qcfg["cached_input"] = [inp.clone().detach() for inp in qcfg["cached_block0_input"]]
+    qcfg["cached_input"] = [
+        inp.clone().detach().to(dev) for inp in qcfg["cached_block0_input"]
+    ]
+    kw_to_use = {
+        kw_org: kw_new
+        for kw_org, kw_new in qcfg["kw_to_cache"].items()
+        if len(qcfg[kw_new]) == len(qcfg["cached_input"])
+    }
     for block_id, m in enumerate(pbar):  # pylint: disable=unused-variable
         m.to(dev)
         for i in range(qcfg["n_samples"]):
@@ -74,16 +83,14 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
                     "alibi": qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
                 }
             else:
-                cached_inp_prev_lay = qcfg["cached_input"][i].to(dev)
+                cached_inp_prev_lay = qcfg["cached_input"][i]
                 data_mb = {
-                    "attention_mask": qcfg["cached_mask"][i].to(dev)
-                    if len(qcfg["cached_mask"]) > 0
-                    else None,
-                    "position_ids": qcfg["position_ids"][i].to(dev),
+                    kw_org: move_to(qcfg[kw_new][i], dev)
+                    for kw_org, kw_new in kw_to_use.items()
                 }

-            with torch.no_grad(), patch_torch_bmm(qcfg):
-                qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0].cpu()
+            with patch_torch_bmm(qcfg):
+                qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0]

         m.cpu()
         torch.cuda.empty_cache()
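move_to (imported above from fms_mo.utils.utils) is called here as move_to(obj, dev); for readers without the repo at hand, a minimal stand-in with the same call shape might look like the following (the real helper may cover more container types):

    import torch

    def move_to_sketch(obj, dev):
        # Recursively move tensors inside common containers to `dev`;
        # non-tensor leaves (None, ints, ...) are returned unchanged.
        if isinstance(obj, torch.Tensor):
            return obj.to(dev)
        if isinstance(obj, dict):
            return {k: move_to_sketch(v, dev) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_to_sketch(v, dev) for v in obj)
        return obj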
