# Train and prepare INT8 checkpoint for the AIU using Direct Quantization
This example builds on the [Direct Quantization (DQ) example](../DQ_SQ/README.md). We assume the user is already familiar with the DQ quantization process and would like to generate an INT8-quantized checkpoint that is compliant with the requirements of the AIU/Spyre accelerator.
Once created, this checkpoint can be run on the AIU by using an inference script from [aiu-fms-testing-utils](https://github.com/foundation-model-stack/aiu-fms-testing-utils).
For more information on the AIU/Spyre accelerator, see the following blogs:
- [Introducing the IBM Spyre AI Accelerator chip](https://research.ibm.com/blog/spyre-for-z)
- [IBM Power modernizes infrastructure and accelerates innovation with AI in the year ahead](https://newsroom.ibm.com/blog-ibm-power-modernizes-infrastructure-and-accelerates-innovation-with-ai-in-the-year-ahead)
## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
## QuickStart
**1. Prepare Data** as per the DQ quantization process ([link](../DQ_SQ/README.md)). In this example, we assume the user wants to quantize the RoBERTa-base model and has thus prepared the DQ data for it, stored under the folders `data_train` and `data_test`, by adapting the DQ example accordingly (a sketch of such an adaptation is shown below).
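
The snippet below is a minimal, hedged sketch of this preparation, assuming the DQ example consumes tokenized Hugging Face datasets saved to disk; the dataset choice, `max_length`, and folder names are illustrative, so follow the DQ README for the authoritative procedure.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

# Illustrative data preparation: tokenize a small corpus and save the splits
# under the folder names used by the command in step 2.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
raw = load_dataset("wikitext", "wikitext-2-raw-v1")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

tokenized = raw.map(tokenize, batched=True, remove_columns=["text"])
tokenized["train"].save_to_disk("data_train")
tokenized["test"].save_to_disk("data_test")
```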
**2. Apply DQ with conversion** by providing the desired quantization parameters, as well as the flags `--save_ckpt_for_aiu` and `--recompute_narrow_weights`.
```bash
python -m fms_mo.run_quant \
    --model_name_or_path "roberta-base" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "pertokenmax" \
    --qw_mode "maxperCh" \
    --qmodel_calibration_new 1 \
    --output_dir "dq_test" \
    --save_ckpt_for_aiu \
    --recompute_narrow_weights
```
> [!TIP]
> - In this example, we are not evaluating the perplexity of the quantized model; if desired, the user can add the `--eval_ppl` flag.
> - We set a single calibration example because the quantizers in use do not need calibration: weights remain static during DQ, so a single example initializes the weight quantizer correctly, and the `pertokenmax` activation quantizer dynamically recomputes the quantization range at inference time, when running on the AIU (see the sketch below).
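
For intuition, here is a minimal sketch of what a per-token max quantizer computes; this is a conceptual illustration, not fms-mo's implementation. Each token (row) gets its own symmetric INT8 scale derived from its largest absolute activation value.

```python
import torch

def quantize_per_token_max(x: torch.Tensor):
    # One symmetric scale per token (row), from that token's max |value|.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return x_int8, scale

x = torch.randn(4, 768)                   # e.g. 4 tokens, RoBERTa-base hidden size
x_q, scale = quantize_per_token_max(x)
x_approx = x_q.to(torch.float32) * scale  # dequantized approximation of x
```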
**3. Reload checkpoint for testing** and validate its content (optional). In the snippet below, the checkpoint filename and the list of quantized-layer name fragments are illustrative assumptions; adjust them to the checkpoint saved in your `--output_dir` and to your model.

```python
import torch

# Illustrative: load the state dict saved in step 2 (adjust the filename).
sd = torch.load("dq_test/qmodel_for_aiu.pt", map_location="cpu")

# Illustrative: name fragments identifying the quantized RoBERTa linear layers.
roberta_qlayers = ["query", "key", "value", "attention.output.dense", "intermediate.dense", "output.dense"]

# assert all quantized weights are int8
assert all(v.dtype == torch.int8 for k, v in sd.items() if any(n in k for n in roberta_qlayers) and k.endswith(".weight"))

# assert all other parameters are fp16
assert all(v.dtype == torch.float16 for k, v in sd.items() if all(n not in k for n in roberta_qlayers) or not k.endswith(".weight"))
```
> [!TIP]
> - We have trained the model with a symmetric quantizer for activations (`qa_mode`). If an asymmetric quantizer is used, the checkpoint will also carry `zero_shift` parameters, which are torch.float32, so this validation step should be modified accordingly (see the sketch below).
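
A hedged sketch of the extra check, assuming the `zero_shift` entries appear in the state dict under keys ending with `zero_shift`:

```python
# Asymmetric activation quantizers add float32 zero_shift parameters;
# validate their dtype separately from the int8/fp16 checks above.
assert all(
    v.dtype == torch.float32
    for k, v in sd.items()
    if k.endswith("zero_shift")
)
```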
Because we have used the `--recompute_narrow_weights` option along with a `maxperCh` (max per-channel) weight quantizer, the distributions of the INT8 weight matrices have been widened. Most per-channel standard deviations should exceed the empirical threshold of 20.
```python
[f"{v.to(torch.float32).std(dim=-1).mean():.4f}" for k, v in sd.items() if k.endswith(".weight") and any(n in k for n in roberta_qlayers)]
```
> [!TIP]
> - We cast the torch.int8 weights to torch.float32 because torch.std does not support integer tensors.
> - For per-channel weights, the recomputation is applied per channel. Here we print the mean across channels for ease of visualization.
> - There is no guarantee that the recomputed weights will exceed the empirical threshold, but they do for several common models of the BERT, RoBERTa, Llama, and Granite families.