
Commit 4a1201e

Merge branch 'main' into int-triton-kernel-adj
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents 553c7a6 + 418f682 commit 4a1201e

7 files changed: 126 additions, 10 deletions


.spellcheck-en-custom.txt

Lines changed: 8 additions & 1 deletion

```diff
@@ -1,6 +1,10 @@
 activations
 acc
 ADR
+aiu
+AIU
+Spyre
+spyre
 Args
 AutoGPTQ
 autoregressive
@@ -91,8 +95,11 @@ quantizes
 Quantizing
 QW
 rceil
+recomputation
 repo
 representable
+roberta
+RoBERTa
 runtime
 Runtime
 SAWB
@@ -112,9 +119,9 @@ Tokenizer
 toml
 triton
 Unquantized
+utils
 vals
 venv
 vllm
 xs
 zp
-
```

examples/AIU_CONVERSION/README.md

Lines changed: 69 additions & 0 deletions (new file; full contents below)

# Train and prepare INT8 checkpoint for the AIU using Direct Quantization

This example builds on the [Direct Quantization (DQ) example](../DQ_SQ/README.md). We assume the user is already familiar with the DQ quantization process and would like to generate an INT8-quantized checkpoint that is compliant with the requirements of the AIU/Spyre accelerator.

Once created, this checkpoint can be run on the AIU using an inference script from [aiu-fms-testing-utils](https://github.com/foundation-model-stack/aiu-fms-testing-utils).

For more information on the AIU/Spyre accelerator, see the following blogs:
- [Introducing the IBM Spyre AI Accelerator chip](https://research.ibm.com/blog/spyre-for-z)
- [IBM Power modernizes infrastructure and accelerates innovation with AI in the year ahead](https://newsroom.ibm.com/blog-ibm-power-modernizes-infrastructure-and-accelerates-innovation-with-ai-in-the-year-ahead)

## Requirements

- [FMS Model Optimizer requirements](../../README.md#requirements)

## QuickStart

**1. Prepare Data** as per the DQ quantization process ([link](../DQ_SQ/README.md)). In this example, we assume the user wants to quantize the RoBERTa-base model and has prepared the DQ data for it, stored under the folders `data_train` and `data_test`, by adapting the DQ example accordingly.

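As a rough illustration only (the authoritative steps live in the DQ example), the two folders could be produced along these lines; the corpus (`wikitext-2`), sequence length, and column names below are assumptions for the sake of the sketch, not part of this repository:

```python
# Hypothetical data-preparation sketch; follow the DQ example for the real steps.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
raw = load_dataset("wikitext", "wikitext-2-raw-v1")  # illustrative corpus choice

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512)

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
tokenized["train"].save_to_disk("data_train")  # consumed via --training_data_path
tokenized["test"].save_to_disk("data_test")    # consumed via --test_data_path
```
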
**2. Apply DQ with conversion** by providing the desired quantization parameters, as well as the flags `--save_ckpt_for_aiu` and `--recompute_narrow_weights`.

```bash
python -m fms_mo.run_quant \
    --model_name_or_path "roberta-base" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "pertokenmax" \
    --qw_mode "maxperCh" \
    --qmodel_calibration_new 1 \
    --output_dir "dq_test" \
    --save_ckpt_for_aiu \
    --recompute_narrow_weights
```

> [!TIP]
> - In this example we do not evaluate the perplexity of the quantized model; if desired, add the `--eval_ppl` flag.
> - We use a single calibration example because the quantizers in use do not need calibration: weights remain static during DQ, so a single example initializes the quantizer correctly, and the `pertokenmax` activation quantizer dynamically recomputes the quantization range at inference time, when running on the AIU (see the illustrative sketch below).

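For intuition, here is a minimal sketch (not fms_mo code) of what a per-token symmetric max quantizer does; the function name and tensor shapes are illustrative:

```python
# Conceptual illustration of per-token symmetric int8 quantization ("pertokenmax"-style).
import torch

def quantize_per_token_max(x: torch.Tensor):
    # x: [num_tokens, hidden]; each token row gets its own scale from its max magnitude,
    # so the range is recomputed dynamically and no calibrated clip value is needed.
    qmax = 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_int8, scale

acts = torch.randn(4, 768)
acts_int8, scales = quantize_per_token_max(acts)
acts_dq = acts_int8.float() * scales  # dequantized approximation of the original activations
```
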
**3. Reload checkpoint for testing** and validate its content (optional).

```python
import torch

sd = torch.load("dq_test/qmodel_for_aiu.pt", weights_only=True)
```

Check that all quantized layers have been converted to `torch.int8`, while the rest remain `torch.float16`.

```python
# select quantized layers by name
roberta_qlayers = [
    "attention.self.query",
    "attention.self.key",
    "attention.self.value",
    "attention.output.dense",
    "intermediate.dense",
    "output.dense",
]
# assert all quantized weights are int8
assert all(
    v.dtype == torch.int8
    for k, v in sd.items()
    if any(n in k for n in roberta_qlayers) and k.endswith(".weight")
)
# assert all other parameters are fp16
assert all(
    v.dtype == torch.float16
    for k, v in sd.items()
    if all(n not in k for n in roberta_qlayers) or not k.endswith(".weight")
)
```

> [!TIP]
> - We trained the model with a symmetric quantizer for activations (`qa_mode`). If an asymmetric quantizer is used, the checkpoint will also carry `zero_shift` parameters, which are `torch.float32`, so this validation step should be modified accordingly (a possible adjustment is sketched below).

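For example, a possible adjustment of the dtype checks (the key matching on `zero_shift` is an assumption about how such entries are named in the checkpoint):

```python
# Hypothetical adjustment for an asymmetric activation quantizer:
# exempt float32 zero_shift entries from the float16 check above.
zero_shift_keys = {k for k in sd if "zero_shift" in k}
assert all(sd[k].dtype == torch.float32 for k in zero_shift_keys)
assert all(
    v.dtype == torch.float16
    for k, v in sd.items()
    if k not in zero_shift_keys
    and (all(n not in k for n in roberta_qlayers) or not k.endswith(".weight"))
)
```
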
Because we used the `--recompute_narrow_weights` option along with a `maxperCh` (max per-channel) weight quantizer, the distributions of the INT weight matrices have been widened. Most per-channel standard deviations should surpass the empirical threshold of 20.

```python
[
    f"{v.to(torch.float32).std(dim=-1).mean():.4f}"
    for k, v in sd.items()
    if k.endswith(".weight") and any(n in k for n in roberta_qlayers)
]
```

> [!TIP]
> - We cast the `torch.int8` weights to `torch.float32` so that `torch.std` can be applied.
> - For per-channel weights, the recomputation is applied per channel. Here we print the mean across channels for ease of visualization; a per-channel variant is sketched below.
> - The recomputed weights are not guaranteed to exceed the empirical threshold, but they do for several common models in the BERT, RoBERTa, Llama, and Granite families.
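
If a finer-grained view than the mean is desired, the fraction of channels per layer that clears the threshold can also be reported. This is a small sketch reusing `sd` and `roberta_qlayers` from above; the threshold of 20 is the empirical value quoted earlier.

```python
# Optional finer-grained check: fraction of channels per quantized layer whose
# per-channel std exceeds the empirical threshold of 20 (not guaranteed to be 100%).
for k, v in sd.items():
    if k.endswith(".weight") and any(n in k for n in roberta_qlayers):
        std_per_ch = v.to(torch.float32).std(dim=-1)
        print(f"{k}: {(std_per_ch > 20).float().mean().item():.1%} of channels above 20")
```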

fms_mo/dq.py

Lines changed: 19 additions & 6 deletions

```diff
@@ -43,6 +43,7 @@
     get_act_scales,
     get_act_scales_1gpu,
 )
+from fms_mo.utils.aiu_utils import save_for_aiu
 from fms_mo.utils.dq_utils import config_quantize_smooth_layers
 from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
 from fms_mo.utils.utils import patch_torch_bmm, prepare_input
@@ -172,7 +173,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
 
     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
-    qcfg["smoothq"] = fms_mo_args.smoothq_alpha != -1
+    qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0
     qcfg["plotsvg"] = False
 
     calibration_dataset = load_from_disk(data_args.training_data_path)
@@ -199,7 +200,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         scale_file.parent.mkdir(exist_ok=False)
 
     if scale_file.exists():
-        act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
+        act_scales = torch.load(
+            scale_file,
+            map_location=getattr(model, "device", dev),
+            weights_only=True,
+        )
     else:
         logger.info("Generate activation scales")
         if qcfg["large_model"]:
@@ -217,11 +222,13 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         save_fname="dq",
     )
     logger.info(f"Quantized model {model}")
+    logger.info("==" * 20)
+
     if qcfg["smoothq"]:
         logger.info("Starting to apply smooth scale")
         dq_llm(model, act_scales, qcfg)
         logger.info("Finished applying smooth scale")
-        logger.info("==" * 20)
+
     if qcfg["qmodel_calibration_new"] > 0:
         logger.info("Starting to calibrate activation clip_val")
         if qcfg["large_model"]:
@@ -238,9 +245,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         with patch_torch_bmm(qcfg):
             model(**data_mb)
 
-    logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
-    model.save_pretrained(opt_args.output_dir, use_safetensors=True)
-    tokenizer.save_pretrained(opt_args.output_dir)
+    if opt_args.save_ckpt_for_aiu:
+        logger.info(
+            f"Saving model processed for AIU and tokenizer to {opt_args.output_dir}"
+        )
+        save_for_aiu(model, qcfg, output_dir=opt_args.output_dir, verbose=True)
+    elif opt_args.save_ckpt:
+        logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
+        model.save_pretrained(opt_args.output_dir, use_safetensors=True)
+        tokenizer.save_pretrained(opt_args.output_dir)
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
```

fms_mo/run_quant.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -315,9 +315,11 @@ def main():
 
         logger = set_log_level(opt_args.log_level, __name__)
 
-        logger.debug(f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, \
-            opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, \
-            fp8_args {fp8_args}")
+        logger.debug(
+            f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, "
+            f"opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, "
+            f"fp8_args {fp8_args}"
+        )
     except Exception as e:  # pylint: disable=broad-except
         logger.error(traceback.format_exc())
         write_termination_log(
```

fms_mo/training_args.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -138,6 +138,14 @@ class OptArguments(TypeChecker):
         default="INFO",
         metadata={"help": "The log level to adopt during optimization."},
     )
+    save_ckpt: bool = field(
+        default=True,
+        metadata={"help": "Save quantized checkpoint."},
+    )
+    save_ckpt_for_aiu: bool = field(
+        default=False,
+        metadata={"help": "Prepare and save AIU-compliant checkpoint."},
+    )
 
 
 @dataclass
@@ -176,6 +184,10 @@ class FMSMOArguments(TypeChecker):
     aiu_sim_triton: bool = field(
         default=False, metadata={"help": ("AIU simulation with triton kernel")}
     )
+    recompute_narrow_weights: bool = field(
+        default=False,
+        metadata={"help": "Apply recomputation during checkpoint saving for AIU."},
+    )
 
 
 @dataclass
```

fms_mo/utils/dq_utils.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -115,5 +115,15 @@ def config_quantize_smooth_layers(qcfg: dict):
         qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
         if "granite-34b-code-instruct" in qcfg["model"]:
             qcfg["smoothq_act_scale_path"] = "./act_scales/graniteCodeHF_34b_base12.pt"
+    elif "roberta" in qcfg["model"]:
+        qcfg["act_scale_path"] = "./act_scales"
+        qcfg["smoothq_scale_layers"] = [
+            "attention.self.query",
+            "attention.self.key",
+            "attention.self.value",
+            "intermediate.dense",
+        ]
+        qcfg["qskip_layer_name"] = []
+        qcfg["qlayer_name_pattern"] = ["roberta.encoder"]
     else:
         raise ValueError("The model architecture is not supported for DQ.")
```

fms_mo/utils/qconfig_utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -114,6 +114,7 @@ def config_defaults() -> dict:
         "extend_act_range": False,
         "plotsvg": False,
         "qskip_large_mag_layers": False,
+        "recompute_narrow_weights": False,
         # Iterable vars
         "qlayer_name_pattern": [],
         "qskip_layer_name": [],
@@ -306,6 +307,7 @@ def qconfig_init(recipe: str = None, args: Any = None) -> dict:
     qcfg["qlayer_name_pattern"] = []
     qcfg["qskip_layer_name"] = []
     qcfg["qskip_large_mag_layers"] = False
+    qcfg["recompute_narrow_weights"] = False
     qcfg["qspecial_layers"] = {}
 
     # settings about quantizing bmm/matmul
@@ -878,6 +880,7 @@ def check_config(config: dict, model_dtype: torch.dtype = None) -> None:
         "ptq_freezecvs",
         "ptq_qdrop",
         "qskip_large_mag_layers",
+        "recompute_narrow_weights",
         "smoothq",
     ]
     for boolean_var_str in boolean_vars_str:
```
