
Commit 2be858a

Merge branch 'main' into gptq_model
Signed-off-by: chichun-charlie-liu <[email protected]>
2 parents 51140cf + 05bb442 commit 2be858a

39 files changed: +3295 / -515 lines

.github/pull_request_template.md

Lines changed: 18 additions & 8 deletions
```diff
@@ -4,16 +4,26 @@
 
 <!-- Please summarize the changes -->
 
-### Related issue number
+### Related issues or PRs
 
-<!-- For example: "Closes #1234" -->
+<!-- For example: "Closes #1234" or "Fixes bug introduced in #5678 -->
 
-### How to verify the PR
+### (Optional) List any documentation or testing added
 
-<!-- Please provide instruction or screenshots on how to verify the PR.-->
+<!-- Describe which features were documented/tested -->
 
-### Was the PR tested
+### (Optional) How to verify the contribution
 
-<!-- Describe how PR was tested -->
-- [ ] I have added >=1 unit test(s) for every new method I have added.
-- [ ] I have ensured all unit tests pass
+<!-- Provide instructions on how to verify your contribution if unit tests do not provide coverage -->
+
+### Checklist for passing CI/CD:
+
+<!-- Mark completed tasks with "- [x]" -->
+- [ ] All commits are signed showing "Signed-off-by: Name \<[email protected]\>" with `git commit -signoff` or equivalent
+- [ ] PR title and commit messages adhere to [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/)
+- [ ] Contribution is formatted with `tox -e fix`
+- [ ] Contribution passes linting with `tox -e lint`
+- [ ] Contribution passes spellcheck with `tox -e spellcheck`
+- [ ] Contribution passes all unit tests with `tox -e unit`
+
+Note: CI/CD performs unit tests on multiple versions of Python from a fresh install. There may be differences with your local environment and the test environment.
```
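The new checklist relies on commit sign-off and Conventional Commits; a minimal illustration of a compliant commit is shown below (the commit message itself is hypothetical, and `--signoff` is the standard long form of the git flag):

```bash
# Add the "Signed-off-by:" trailer and use a Conventional Commits prefix
# such as fix:, feat:, docs:, or chore: in the subject line.
git commit --signoff -m "fix: guard against empty calibration dataset"
```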

.gitignore

Lines changed: 6 additions & 1 deletion
```diff
@@ -45,4 +45,9 @@ fms_mo.log
 data*_train/
 data*_test/
 act_scales/
-examples/
+examples/**/*.json
+examples/**/*.safetensors
+examples/**/*.log
+examples/**/*.sh
+examples/**/*.pt
+examples/**/*.arrow
```

.pylintrc

Lines changed: 1 addition & 1 deletion
```diff
@@ -94,7 +94,7 @@ persistent=yes
 
 # Minimum Python version to use for version dependent checks. Will default to
 # the version used to run pylint.
-py-version=3.9
+py-version=3.10
 
 # Discover python modules and packages in the file system subtree.
 recursive=no
```

.spellcheck-en-custom.txt

Lines changed: 20 additions & 1 deletion
```diff
@@ -1,6 +1,10 @@
 activations
 acc
 ADR
+aiu
+AIU
+Spyre
+spyre
 Args
 autoregressive
 backpropagation
@@ -22,8 +26,10 @@ dequantization
 dq
 DQ
 dev
+dtype
 eval
 fms
+fmsmo
 fp
 FP
 FP8Arguments
@@ -91,8 +97,11 @@ quantizes
 Quantizing
 QW
 rceil
+recomputation
 repo
 representable
+roberta
+RoBERTa
 runtime
 Runtime
 SAWB
@@ -112,9 +121,19 @@ Tokenizer
 toml
 triton
 Unquantized
+utils
 vals
 venv
 vllm
 xs
 zp
-
+microxcaling
+Microscaling
+microscaling
+MX
+mx
+MXINT
+mxint
+MXFP
+mxfp
+OCP
```

examples/AIU_CONVERSION/README.md

Lines changed: 69 additions & 0 deletions (new file)

# Train and prepare INT8 checkpoint for the AIU using Direct Quantization

This example builds on the [Direct Quantization (DQ) example](../DQ_SQ/README.md). We assume the user is already familiar with the DQ quantization process and would like to generate an INT8-quantized checkpoint that complies with the requirements of the AIU/Spyre accelerator.

Once created, this checkpoint can be run on the AIU by using an inference script from [aiu-fms-testing-utils](https://github.com/foundation-model-stack/aiu-fms-testing-utils).

For more information on the AIU/Spyre accelerator, see the following blogs:
- [Introducing the IBM Spyre AI Accelerator chip](https://research.ibm.com/blog/spyre-for-z)
- [IBM Power modernizes infrastructure and accelerates innovation with AI in the year ahead](https://newsroom.ibm.com/blog-ibm-power-modernizes-infrastructure-and-accelerates-innovation-with-ai-in-the-year-ahead)

## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)

## QuickStart

**1. Prepare Data** as per the DQ quantization process ([link](../DQ_SQ/README.md)). In this example, we assume the user wants to quantize the RoBERTa-base model and has prepared the DQ data for it by adapting the DQ example accordingly, storing it under the folders `data_train` and `data_test`.
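If the DQ data has not been generated yet, a minimal sketch of the tokenization step is shown below. It mirrors the MX example later in this commit: the `get_tokenized_data` helper and the `wiki` dataset name are taken from that example, while the sample count and sequence length are illustrative choices for RoBERTa-base.

```python
from transformers import AutoTokenizer

from fms_mo.utils.calib_data import get_tokenized_data

# Tokenize a small calibration set with the RoBERTa tokenizer; with
# path_to_save="data" the splits are written to data_train/ and data_test/.
# 512 is RoBERTa's maximum sequence length; 128 samples is an arbitrary choice.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=True)
get_tokenized_data("wiki", 128, 512, tokenizer, path_to_save="data")
```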
**2. Apply DQ with conversion** by providing the desired quantization parameters, as well as the flags `--save_ckpt_for_aiu` and `--recompute_narrow_weights`.

```bash
python -m fms_mo.run_quant \
    --model_name_or_path "roberta-base" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "pertokenmax" \
    --qw_mode "maxperCh" \
    --qmodel_calibration_new 1 \
    --output_dir "dq_test" \
    --save_ckpt_for_aiu \
    --recompute_narrow_weights
```
> [!TIP]
> - In this example we do not evaluate the perplexity of the quantized model; if desired, add the `--eval_ppl` flag.
> - We set a single calibration example because the quantizers in use do not need calibration: weights remain static during DQ, so a single example initializes the weight quantizer correctly, and the activation quantizer `pertokenmax` dynamically recomputes the quantization range at inference time when running on the AIU.

**3. Reload checkpoint for testing** and validate its content (optional).

```python
import torch

sd = torch.load("dq_test/qmodel_for_aiu.pt", weights_only=True)
```

Check that all quantized layers have been converted to `torch.int8`, while the rest remain `torch.float16`.

```python
# select quantized layers by name
roberta_qlayers = ["attention.self.query", "attention.self.key", "attention.self.value", "attention.output.dense", "intermediate.dense", "output.dense"]
# assert all quantized weights are int8
assert all(v.dtype == torch.int8 for k, v in sd.items() if any(n in k for n in roberta_qlayers) and k.endswith(".weight"))
# assert all other parameters are fp16
assert all(v.dtype == torch.float16 for k, v in sd.items() if all(n not in k for n in roberta_qlayers) or not k.endswith(".weight"))
```

> [!TIP]
> - We trained the model with a symmetric quantizer for activations (`qa_mode`). If an asymmetric quantizer is used, the checkpoint will also carry `zero_shift` parameters, which are `torch.float32`, so this validation step should be adjusted accordingly.

Because we used the `--recompute_narrow_weights` option along with a `maxperCh` (max per-channel) quantizer for weights, the distributions of the INT weight matrices have been widened. Most per-channel standard deviation values should surpass the empirical threshold of 20.

```python
[f"{v.to(torch.float32).std(dim=-1).mean():.4f}" for k, v in sd.items() if k.endswith(".weight") and any(n in k for n in roberta_qlayers)]
```

> [!TIP]
> - We cast the `torch.int8` weights to `torch.float32` to be able to apply `torch.std`.
> - For per-channel weights, the recomputation is applied per channel. Here we print the mean across channels to ease visualization.
> - There is no guarantee that the recomputed weights will exceed the empirical threshold, but it holds for several common models in the BERT, RoBERTa, Llama, and Granite families.

examples/MX/README.md

Lines changed: 98 additions & 0 deletions (new file)

# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)

Microscaling, or "MX", formats such as `MXFP8` differ from the commonly used FP8 formats. For example, PyTorch provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. All of the `mx` formats, on the other hand, are group-based data structures: each member of a group uses the specified element format (e.g. FP8 for MXFP8), while the group shares a single (usually 8-bit) scale. The group size can be as small as 32 or 16, depending on the hardware design. One may therefore consider each MXFP8 number to require 8.25 bits (for a group size of 32) instead of 8 bits. More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
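As a quick sanity check of the storage-cost figure above, the arithmetic can be written out directly (group sizes as stated in the paragraph; the function below is purely illustrative):

```python
def effective_bits(element_bits: int = 8, scale_bits: int = 8, group_size: int = 32) -> float:
    """Average storage per element when one scale is shared by the whole group."""
    return element_bits + scale_bits / group_size

print(effective_bits())               # 8.25 bits per MXFP8 element with group size 32
print(effective_bits(group_size=16))  # 8.5 bits if the group size shrinks to 16
```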
Here, we provide two simple examples of using the MX format in `fms-mo`.

> [!NOTE]
> It is important to keep in mind that `mx` is not yet natively supported by Hopper GPUs (some formats will be supported by Blackwell), which means the quantization configurations and corresponding behavior are simulated. Hence, no real "speed up" should be expected.

## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft `microxcaling` Python package; download it [here](https://github.com/microsoft/microxcaling.git).

> [!TIP]
> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for `PyTorch` packages. We have created a patching solution to resolve this; run the following on the command line:

```bash
python3 ../install_patches.py
```

This patching file will either download the repo for you or look for an already installed version in `$HOME` or the current working directory, and then install the patch.
For more information, see `patches/README.md`.

## QuickStart

### Example 1
The first example is based on a toy model with only a few Linear layers, in which a single Linear layer is quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can simply be run as follows:

```bash
python simple_mx_example.py
```

A comparison between the different formats, including the first three elements of the output tensors and the norm of the difference from the FP32 reference, is shown below.

| dtype | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|<sub>2</sub> |
|:-----------|---------------:|---------------:|---------------:|------------------------:|
| fp32 | -1.0491 | 0.5312 | -1.6387 | 0.0000 |
| fmsmo_int8 | -1.0577 | 0.5346 | -1.6508 | 0.4937 |
| fmsmo_int4 | -0.5885 | 0.5831 | -1.7976 | 8.2927 |
| mxint8 | -0.6444 | 0.6828 | -1.8626 | 8.3305 |
| mxint4 | -0.9089 | 0.6141 | -1.7630 | 8.0692 |
| mxfp8_e4m3 | -0.8031 | 0.7262 | -1.9581 | 7.8554 |
| mxfp8_e5m2 | -0.8471 | 0.7319 | -1.7458 | 8.1838 |
| mxfp4_e2m1 | -0.7506 | 0.6123 | -1.9311 | 7.9936 |
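The numbers above come from `simple_mx_example.py`. For intuition only, a self-contained sketch of this kind of comparison is shown below; it uses plain PyTorch fake-quantization of a random Linear layer (per-channel INT8 absmax), not the `microxcaling` kernels, so its outputs will not reproduce the table.

```python
import torch

torch.manual_seed(0)
lin = torch.nn.Linear(64, 8, bias=False)
x = torch.randn(4, 64)
ref = lin(x)  # FP32 reference output

# Simulate per-channel INT8 weight quantization (absmax scaling), then dequantize.
w = lin.weight.detach()
scale = w.abs().amax(dim=1, keepdim=True) / 127.0
w_dq = torch.clamp((w / scale).round(), -127, 127) * scale
out_q = x @ w_dq.T  # output computed with the fake-quantized weights

print(ref[0, :3].tolist(), out_q[0, :3].tolist())  # first few output elements
print(torch.linalg.norm(ref - out_q).item())       # ||ref - out||_2
```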
### Example 2
The second example is the same as the [DQ example](../DQ_SQ/README.md), except that it uses the [microxcaling](https://arxiv.org/abs/2310.10537) format. We only demonstrate `mxfp8` and `mxfp4` here, but MXINT8, MXFP8, MXFP6, and MXFP4 are also available for weights, activations, and/or the KV-cache.

**1. Prepare Data** for the calibration process by converting it into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate what data format `fms_mo` is expecting.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.

**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, the weight quantizer (`qw_mode`), the activation quantizer (`qa_mode`), etc. An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --output_dir "dq_test" \
    --eval_ppl
```
> [!NOTE]
> To use an MX format, simply assign the `qa_mode` and `qw_mode` arguments a value of the form `mx_<dtype supported by mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will be used in place of `QLinear` as in the other examples.

**3. Compare the Perplexity score.** For user convenience, the code prints the perplexity (controlled by the `--eval_ppl` flag) at the end of the run, so no additional steps are needed (provided the logging level is set to `INFO` in the terminal). You can also check the output in the log file `./fms_mo.log`.
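If the terminal output has scrolled away, a simple way to retrieve the result from the log (assuming the default log file name mentioned above) is:

```bash
# Print any perplexity lines recorded in the run log.
grep -i "perplexity" ./fms_mo.log
```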

## Example Test Results
The perplexity of the INT8, FP8, and MX quantized models on the `wikitext` dataset is shown below:

| Model |Type |QA |QW |DQ |SQ |Perplexity|
|:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:|
|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.22 |
| |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 |
| |**MX**|mx_fp8_e4m3 |mx_fp8_e4m3 |yes |**no** |6.23 |
| |**MX**|mx_fp4_e2m1 |mx_fp4_e2m1 |yes |**no** |8.22 |

> [!NOTE]
> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.
