
Commit 136c12f

Merge pull request #70 from chichun-charlie-liu/dq_fix
fix: multiple bug fixes:
2 parents 3c07b46 + 3f7cc67 commit 136c12f


8 files changed: +103 -89 lines


.gitignore

Lines changed: 2 additions & 2 deletions
@@ -42,6 +42,6 @@ error.log
 
 # Files generated from running examples
 fms_mo.log
-data_train/
-data_test/
+data*_train/
+data*_test/
 act_scales/

.spellcheck-en-custom.txt

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,5 @@
 activations
+acc
 ADR
 Args
 AutoGPTQ
@@ -67,6 +68,7 @@ NLP
 Nouterloop
 Nvidia
 Nvidia's
+openai
 orchestrator
 param
 pre
@@ -99,6 +101,8 @@ SmoothQuant
 socio
 sparsification
 SQuAD
+stderr
+Stderr
 straightforward
 tokenization
 tokenized

examples/FP8_QUANT/README.md

Lines changed: 10 additions & 12 deletions
@@ -73,20 +73,18 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_mo
 
 ## Example Test Results
 - BF16 (not quantized) LLAMA3-8B model.
-``` bash
-| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
-|lambada_openai| 1|none | 5|acc ||0.7120|± |0.0287|
-| | |none | 5|perplexity||3.8683|± |0.3716|
-```
+
+| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
+|lambada_openai| 1|none | 5|acc ||0.7120|± |0.0287|
+| | |none | 5|perplexity||3.8683|± |0.3716|
 
 - FP8 quantized LLAMA3-8B model.
-``` bash
-| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
-|lambada_openai| 1|none | 5|acc ||0.7160|± |0.0286|
-| | |none | 5|perplexity||3.8915|± |0.3727|
-```
+
+| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
+|lambada_openai| 1|none | 5|acc ||0.7160|± |0.0286|
+| | |none | 5|perplexity||3.8915|± |0.3727|
 
 ## Code Walk-through
 

fms_mo/dq.py

Lines changed: 20 additions & 14 deletions
@@ -21,7 +21,6 @@
 # Standard
 from pathlib import Path
 import logging
-import os
 
 # Third Party
 from datasets import load_from_disk
@@ -114,7 +113,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         revision="main",
         use_auth_token=True if model_args.use_auth_token else None,
         torch_dtype=torch_dtype,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+        device_map="auto" if model_args.low_cpu_mem_usage else None,
     )
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
@@ -125,7 +125,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
     qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-    # for models that cannot fit in 1 GPU, keep it in CPU and use block-wise calibration.
+    # for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration.
+    # or leverage HF's device_map="auto"
     total_gpu_memory = 1e-5
     if torch.cuda.is_available():
         total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -143,7 +144,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["large_model"] = any(
         name in model_args.model_name_or_path for name in known_large_models
     ) or (gpu_mem_util_per > 0.7)
-    dev = "cpu" if qcfg["large_model"] else "cuda:0"
+    dev = "cpu" if qcfg["large_model"] else "cuda"
+    model.to(dev)
 
     if hasattr(model.config, "model_type"):
         qcfg["model_type"] = model.config.model_type
@@ -180,23 +182,27 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         batch_size=1,
     )
 
-    # For loading or creating smoothquant scale.
-    act_scale_directory = "./act_scales"
-    if not os.path.exists(act_scale_directory):
-        os.makedirs(act_scale_directory)
+    # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
+    scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
+    if qcfg.get("act_scale_path", None):
+        # user provided a scale file (or a dir)
+        scale_file_or_dir = Path(qcfg["act_scale_path"])
+        if scale_file_or_dir.is_dir():
+            scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
+        elif scale_file_or_dir.is_file():
+            scale_file = scale_file_or_dir
 
-    if qcfg["act_scale_path"] is not None:
-        act_scales = torch.load(qcfg["act_scale_path"], map_location="cpu")
+    if not scale_file.parent.exists():
+        scale_file.parent.mkdir(exist_ok=False)
+
+    if scale_file.exists():
+        act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
     else:
         logger.info("Generate activation scales")
         if qcfg["large_model"]:
             act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
         else:
-            if gpu_mem_util_per < 0.7:
-                model.to(dev)
-
             act_scales = get_act_scales(model, dq_dataloader, qcfg)
-        scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
         torch.save(act_scales, scale_file)
 
     qmodel_prep(
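For readers following the new `act_scales` handling in `dq.py`, here is a minimal, self-contained sketch of the same path-resolution idea. The helper name `resolve_scale_file`, the default directory, and the example model id are illustrative assumptions, not part of the commit:

```python
from pathlib import Path


def resolve_scale_file(model_name: str, act_scale_path: str | None = None) -> Path:
    """Pick the .pt file that holds SmoothQuant activation scales.

    Mirrors the logic added in dq.py: a user-supplied path may be a directory
    (look up <model>.pt inside it) or a concrete file; otherwise fall back to
    ./act_scales/<model>.pt.
    """
    default = Path(f"./act_scales/{model_name.replace('/', '-')}.pt")
    if act_scale_path:
        p = Path(act_scale_path)
        if p.is_dir():
            return p / f"{model_name.replace('/', '-')}.pt"
        if p.is_file():
            return p
    return default


if __name__ == "__main__":
    # e.g. "meta-llama/Meta-Llama-3-8B" -> act_scales/meta-llama-Meta-Llama-3-8B.pt
    print(resolve_scale_file("meta-llama/Meta-Llama-3-8B"))
```

In the commit itself, the parent directory is then created if missing, and the file is either loaded or regenerated via `get_act_scales` / `get_act_scales_1gpu` and saved.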

fms_mo/modules/linear.py

Lines changed: 8 additions & 2 deletions
@@ -93,6 +93,10 @@ def __init__(
                 Defaults to 32.
             qw_mode (str, optional): Quantization mode for weight. Defaults to None.
             **kwargs (dict): Additional keyword arguments.
+
+        Note:
+            scales could be of higher precision than x or W, need to make sure qinput.dtype after
+            Qa(x/scale) are consistent with x. Same for W
         """
 
         super().__init__(
@@ -275,15 +279,17 @@ def forward(self, x):
                 # pylint: disable=not-callable
                 return F.linear(x, self.W_fp, self.bias)
             else:
-                qinput = self.quantize_feature(x / scale)
+                qinput = self.quantize_feature(x / scale).to(x.dtype)
                 # Default self.update_type == 'hard' pruning.
                 if self.mask is not None:
                     pweight = HardPrune.apply(
                         self.weight, self.mask.to(self.weight.device), self.p_inplace
                     )
                     qweight = self.quantize_weight(pweight)
                 else:
-                    qweight = self.quantize_weight(self.weight * scale)
+                    qweight = self.quantize_weight(self.weight * scale).to(
+                        self.weight.dtype
+                    )
 
                 qbias = self.bias
 
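The two `.to(...)` casts above address the mixed-precision issue called out in the new docstring note: when the SmoothQuant scale is kept in higher precision than the activations or weights, dividing or multiplying by it silently promotes the result. A minimal sketch of that promotion with toy tensors (no fms_mo code involved):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)  # activations in bf16
scale = torch.rand(8, dtype=torch.float32)   # per-channel scale kept in fp32

scaled = x / scale           # PyTorch type promotion: bf16 / fp32 -> fp32
print(scaled.dtype)          # torch.float32

scaled = scaled.to(x.dtype)  # the commit's fix: cast back to the activation dtype
print(scaled.dtype)          # torch.bfloat16
```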
fms_mo/quant/ptq.py

Lines changed: 32 additions & 24 deletions
@@ -1943,14 +1943,12 @@ def __init__(self, module, qcfg):
     def forward(self, inp, **kwargs):
         self.qcfg["cached_block0_input"].append(inp.cpu())
         self.qcfg["cache_id"] += 1
-        for k, v in kwargs.items():
-            if k == "attention_mask":
-                if v is not None:
-                    self.qcfg["cached_mask"].append(v.cpu())
-            if k == "alibi":
-                self.qcfg["cached_alibi"].append(v.cpu())
-            if k == "position_ids":
-                self.qcfg["position_ids"].append(v.cpu())
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(move_to(v, "cpu"))
         raise ValueError
 
 
@@ -1965,14 +1963,15 @@ def __init__(self, module, qcfg):
         self.module = module
 
     def forward(self, **kwargs):
-        for k, v in kwargs.items():
-            if k == "x":
-                self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = v.cpu()
-                self.qcfg["cache_id"] += 1
-            if k == "mask":
-                self.qcfg["cached_mask"] = v.cpu()
-            if k == "rel_pos_bias":
-                self.qcfg["cached_pos_bias"] = v.cpu()
+        self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = kwargs["x"].cpu()
+        self.qcfg["cache_id"] += 1
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"]:
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(v.cpu())
+
         raise ValueError
 
 
@@ -2126,13 +2125,21 @@ def cache_block0_inputs(
     qcfg["cache_id"] = 0
     qcfg["cached_mask"] = []
     qcfg["cached_alibi"] = []
-    qcfg[
-        "position_ids"
-    ] = []  # latest transformers requires pos_ids to be fed into fwd()
     # move block0 to GPU and excuting fwd() until finish block0
     if "fms" in qcfg["model_type"]:
+        qcfg["kw_to_cache"] = {
+            "mask": "cached_mask",
+            "rel_pos_bias": "cached_pos_bias",
+        }
         blocks[0] = RunFMModule(blocks[0], qcfg)
     else:
+        # latest transformers requires pos_ids to be fed into fwd()
+        qcfg["kw_to_cache"] = {
+            "attention_mask": "cached_mask",
+            "alibi": "cached_alibi",
+            "position_ids": "position_ids",
+            "position_embeddings": "position_embeddings",
+        }
         blocks[0] = RunModule(blocks[0], qcfg)
 
     if isinstance(dloader, torch.utils.data.DataLoader):
@@ -2464,12 +2471,13 @@ def get_module_act_scales(m, block_idx, qcfg, act_scales):
                 alibi=qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
             )[0].cpu()
         else:
+            kwargs = {
+                kw_org: move_to(qcfg[kw_qcfg][i], dev) if qcfg[kw_qcfg] != [] else None
+                for kw_org, kw_qcfg in qcfg["kw_to_cache"].items()
+            }
             qcfg["cached_input"][i] = m(
                 qcfg["cached_input"][i].to(dev),
-                attention_mask=None
-                if qcfg["cached_mask"] == []
-                else qcfg["cached_mask"][i].to(dev),
-                position_ids=qcfg["position_ids"][i].to(dev),
+                **kwargs,
             )[0].cpu()
     for h in hooks:
        h.remove()
@@ -2482,7 +2490,7 @@ def get_act_scales_1gpu(model, dloader, qcfg):
     """
     get activation blocks on 1gpu for very large models that cannot fit in 1gpu
     """
-    dev = "cuda:0"
+    dev = "cuda"
     qcfg["batch_size"] = 1
     qcfg["loader_len"] = len(dloader)
     qcfg["dtype"] = next(iter(model.parameters())).dtype

fms_mo/training_args.py

Lines changed: 7 additions & 0 deletions
@@ -56,6 +56,13 @@ class ModelArguments(TypeChecker):
 
     model_name_or_path: str = field(default="facebook/opt-125m")
     torch_dtype: str = field(default="bfloat16")
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": "When set to True, leverage device_map='auto' and let HF to move modules"
+            "between cpu and cuda automatically during inference."
+        },
+    )
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
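The new `ModelArguments.low_cpu_mem_usage` flag feeds directly into the `from_pretrained` call shown in the `dq.py` diff. A hedged sketch of the effect: the checkpoint id is just the default from `ModelArguments`, and both `low_cpu_mem_usage` and `device_map` are standard Hugging Face `from_pretrained` arguments:

```python
import torch
from transformers import AutoModelForCausalLM

low_cpu_mem_usage = True  # would come from ModelArguments / the CLI in fms_mo

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # example checkpoint
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=low_cpu_mem_usage,
    # with device_map="auto", accelerate places modules across CPU and GPUs as needed
    device_map="auto" if low_cpu_mem_usage else None,
)
print(model.device)
```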

fms_mo/utils/dq_utils.py

Lines changed: 20 additions & 35 deletions
@@ -38,50 +38,35 @@ def config_quantize_smooth_layers(qcfg):
         "granite-20b-code",
         "granite-20b-code",
     ]
-    if any(model in qcfg["model"] for model in llama_architecture) or any(
-        model in qcfg["model_type"] for model in llama_architecture
+    if (
+        any(model in qcfg["model"] for model in llama_architecture)
+        or any(model in qcfg["model_type"] for model in llama_architecture)
+        and qcfg["qskip_large_mag_layers"]
     ):
         qcfg["qlayer_name_pattern"] = ["model.layers."]
         qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
-        qcfg["qskip_layer_name"] = []
-        if "2-7b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [1, 30]
-                ]
-        if "2-13b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [3, 37]
-                ]
-        if "2-70b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [2, 8, 79]
-                ]
-        if "3-8B" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [1, 31]
-                ]
-        if "3-70B" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [3, 78, 79]
-                ]
-        if "405B-Instruct" in qcfg["model"]:  # llama3.1
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [5, 124, 125]
+        large_mag_layers = {
+            "2-7b": [1, 30],
+            "2-70b": [2, 8, 79],
+            "3-8B": [1, 31],
+            "3-70B": [3, 78, 79],
+            "405B-Instruct": [5, 124, 125],
+        }
+        for llama_family, layers in large_mag_layers.items():
+            if llama_family in qcfg["model"]:
+                qcfg["qskip_layer_name"] += [
+                    f"model.layers.{i}.mlp.down_proj" for i in layers
                 ]
+                break
+
     elif "mixtral" in qcfg["model"]:
         qcfg["qlayer_name_pattern"] = (
             ["model.layers"] if qcfg["nbits_bmm1"] == 32 else []
         )
         qcfg["scale_layers"] = ["q_proj", "k_proj", "v_proj", "w1", "w3"]
-        qcfg["qskip_layer_name"] = []
-        for i in range(32):
-            qcfg["qskip_layer_name"].append(f"model.layers.{i}.block_sparse_moe.gate")
+        qcfg["qskip_layer_name"] += [
+            f"model.layers.{i}.block_sparse_moe.gate" for i in range(32)
+        ]
         if qcfg["qskip_large_mag_layers"]:
             qcfg["qskip_layer_name"] += [
                 f"model.layers.{i}.block_sparse_moe.experts.{j}.w2"