
Commit 05bb442

Merge pull request #110 from chichun-charlie-liu/mx_impl
feat: mx integration
2 parents 16ed9a2 + 4e0de3f commit 05bb442

23 files changed: +1658 −172 lines changed

.gitignore

Lines changed: 6 additions & 1 deletion
```diff
@@ -45,4 +45,9 @@ fms_mo.log
 data*_train/
 data*_test/
 act_scales/
-examples/
+examples/**/*.json
+examples/**/*.safetensors
+examples/**/*.log
+examples/**/*.sh
+examples/**/*.pt
+examples/**/*.arrow
```

.spellcheck-en-custom.txt

Lines changed: 12 additions & 0 deletions
```diff
@@ -27,8 +27,10 @@ dequantization
 dq
 DQ
 dev
+dtype
 eval
 fms
+fmsmo
 fp
 FP
 FP8Arguments
@@ -125,3 +127,13 @@ venv
 vllm
 xs
 zp
+microxcaling
+Microscaling
+microscaling
+MX
+mx
+MXINT
+mxint
+MXFP
+mxfp
+OCP
```

examples/MX/README.md

Lines changed: 98 additions & 0 deletions
# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)

The microscaling ("MX") formats, such as `MXFP8`, are numeric formats that differ from the commonly used FP8 formats. For example, PyTorch provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. The `mx` formats, on the other hand, are group-based data structures: every element of a group uses the specified element format (e.g., FP8 for MXFP8), while the group shares a single, usually 8-bit, "scale". The group size can be as small as 32 or 16, depending on the hardware design. Each MXFP8 number therefore effectively costs 8.25 bits when the group size is 32 (8 bits per element plus 8 shared scale bits per 32 elements) rather than 8 bits. More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
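
To make the group structure concrete, below is a highly simplified, self-contained sketch of MX-style fake quantization (an illustration only, not the `microxcaling` implementation): each group of 32 values shares one power-of-two scale, and each element is then quantized to the element format, int8 in this sketch.

```python
import torch


def fake_quant_mx_int8(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Toy MX-style fake quantization: one shared power-of-two scale per group,
    int8 element payloads. Illustrative only; not the microxcaling package."""
    g = x.reshape(-1, group_size)
    max_abs = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # Shared exponent (stored in 8 bits): pick 2^e so the largest magnitude lands near the int8 max.
    shared_exp = torch.floor(torch.log2(max_abs)) - 6
    scale = torch.pow(2.0, shared_exp)
    q = torch.clamp(torch.round(g / scale), -128, 127)  # per-element int8 payload
    # Storage cost per element: 8 bits + 8 shared-scale bits / 32 elements = 8.25 bits.
    return (q * scale).reshape(x.shape)


x = torch.randn(4, 64)
print(torch.norm(x - fake_quant_mx_int8(x)))  # quantization error vs. the FP32 original
```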

Here we provide two simple examples of using the MX formats in `fms-mo`.

> [!NOTE]
> It is important to keep in mind that `mx` is not yet natively supported by Hopper GPUs (some formats will be supported on Blackwell), which means the quantization configurations and the corresponding behavior are simulated here. Hence, no real speed-up should be expected.


## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git).

> [!TIP]
> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for their `PyTorch` packages. We have created a patching solution to resolve this; run the following on the command line:
>
> ```bash
> python3 ../install_patches.py
> ```
>
> This patching script will either download the repo for you or look for an already installed copy in `$HOME` or the current working directory, and then apply the patch.
> For more information, see `patches/README.md`.

## QuickStart

### Example 1
The first example is based on a toy model with only a few `Linear` layers, in which only one `Linear` layer will be quantized using the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can be run as follows:

```bash
python simple_mx_example.py
```

A comparison between the formats, including the first three elements of the output tensor and the L2 norm of the difference from the FP32 reference, is shown below.

| dtype      |   output[0, 0] |   output[0, 1] |   output[0, 2] |   \|\|ref - out_dtype\|\|<sub>2</sub> |
|:-----------|---------------:|---------------:|---------------:|------------------------:|
| fp32       |        -1.0491 |         0.5312 |        -1.6387 |                   0.0000 |
| fmsmo_int8 |        -1.0577 |         0.5346 |        -1.6508 |                   0.4937 |
| fmsmo_int4 |        -0.5885 |         0.5831 |        -1.7976 |                   8.2927 |
| mxint8     |        -0.6444 |         0.6828 |        -1.8626 |                   8.3305 |
| mxint4     |        -0.9089 |         0.6141 |        -1.7630 |                   8.0692 |
| mxfp8_e4m3 |        -0.8031 |         0.7262 |        -1.9581 |                   7.8554 |
| mxfp8_e5m2 |        -0.8471 |         0.7319 |        -1.7458 |                   8.1838 |
| mxfp4_e2m1 |        -0.7506 |         0.6123 |        -1.9311 |                   7.9936 |

### Example 2
The second example is the same as the [DQ example](../DQ_SQ/README.md), except that it uses the [microxcaling](https://arxiv.org/abs/2310.10537) formats. We only demonstrate `mxfp8` and `mxfp4` here, but MXINT8, MXFP8, MXFP6, and MXFP4 are also available for weights, activations, and/or KV-cache.

**1. Prepare Data** for the calibration process by converting it into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is shown below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate the data format `fms_mo` is expecting.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.

**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, the weight quantizer (`qw_mode`), the activation quantizer (`qa_mode`), etc. An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --output_dir "dq_test" \
    --eval_ppl
```
> [!NOTE]
> To use an MX format, simply assign the `qa_mode` and `qw_mode` arguments a value of the form `mx_<dtype supported by the mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will then be used in place of the `QLinear` wrappers seen in the other examples.

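The same switch can be made programmatically when calling `fms-mo` from Python (see `examples/MX/simple_mx_example.py` added in this commit). A minimal sketch, with a purely illustrative toy model and input:

```python
import torch

from fms_mo import qconfig_init, qmodel_prep

# Toy model for illustration; three Linear layers because, by default,
# the first/last Linear layers are skipped during quantization.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.Linear(128, 128), torch.nn.Linear(128, 128)
).to("cuda")
x = torch.randn(16, 128, device="cuda")

qcfg = qconfig_init()
qcfg["qw_mode"] = "mx_fp8_e4m3"  # an "mx_"-prefixed mode triggers the MX path
qcfg["qa_mode"] = "mx_fp8_e4m3"  # qcfg["mapping"] / qcfg["mx_specs"] are updated automatically
qmodel_prep(model, x, qcfg)      # QLinearMX wrappers replace the eligible Linear layers
print(model)
```
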
**3. Compare the Perplexity score.** For user convenience, the code will print the perplexity (controlled by the `--eval_ppl` flag) at the end of the run, so no additional steps are needed (if the logging level is set to `INFO` in the terminal). You can also check the output in the log file `./fms_mo.log`.


## Example Test Results
The perplexity of the INT8, FP8, and MX quantized models on the `wikitext` dataset is shown below:

| Model     |Type  |QA            |QW            |DQ  |SQ    |Perplexity|
|:---------:|:----:|:------------:|:------------:|:--:|:----:|:--------:|
|`Llama3-8b`|INT8  |maxpertoken   |maxperCh      |yes |yes   |6.22      |
|           |FP8   |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes   |6.19      |
|           |**MX**|mx_fp8_e4m3   |mx_fp8_e4m3   |yes |**no**|6.23      |
|           |**MX**|mx_fp4_e2m1   |mx_fp4_e2m1   |yes |**no**|8.22      |


> [!NOTE]
> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.
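
For reference, the gating added to `fms_mo/dq.py` in this commit boils down to a single condition; a self-contained illustration with example values:

```python
# SmoothQuant is enabled only if smoothq_alpha is set AND no MX config ("mx_specs") is present.
qcfg = {"smoothq_alpha": 0.5, "mx_specs": {}}  # example values only
qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
print(qcfg["smoothq"])  # False: MX is in use, so SmoothQuant stays off
```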

examples/MX/simple_mx_example.py

Lines changed: 123 additions & 0 deletions
```python
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""

# Third Party
import numpy as np
import torch
import torch.nn.functional as F


class ResidualMLP(torch.nn.Module):
    def __init__(self, hidden_size, device="cuda"):
        super().__init__()

        self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
        self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
        self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
        # Add a dummy layer because, by default, the first/last layers are skipped;
        # with only two layers, all of them would be skipped.
        self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)

    def forward(self, inputs):
        norm_outputs = self.layernorm(inputs)

        # MLP
        proj_outputs = self.dense_4h(norm_outputs)
        proj_outputs = F.gelu(proj_outputs)
        mlp_outputs = self.dense_h(proj_outputs)
        mlp_outputs = self.dummy(mlp_outputs)

        # Residual Connection
        outputs = inputs + mlp_outputs

        return outputs


if __name__ == "__main__":
    # Third Party
    from tabulate import tabulate

    # Local
    from fms_mo import qconfig_init, qmodel_prep

    HIDDEN_DIM = 128
    x = np.random.randn(16, HIDDEN_DIM)
    x = torch.tensor(x, dtype=torch.float32, device="cuda")
    results = {
        "dtype": [],
        "output[0, 0]": [],
        "output[0, 1]": [],
        "output[0, 2]": [],
        "||ref - out_dtype||_2": [],
    }

    # --- Test 0. Run MLP as is
    model = ResidualMLP(HIDDEN_DIM)
    with torch.no_grad():
        out = model(x)
    results["dtype"].append("fp32")
    results["output[0, 0]"].append(out[0, 0])
    results["output[0, 1]"].append(out[0, 1])
    results["output[0, 2]"].append(out[0, 2])
    results["||ref - out_dtype||_2"].append(0)
    print(model)

    # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
    qcfg = qconfig_init()
    qcfg["nbits_a"] = 8
    qcfg["nbits_w"] = 8
    qmodel_prep(model, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fmsmo_int8")
    results["output[0, 0]"].append(out_dtype[0, 0])
    results["output[0, 1]"].append(out_dtype[0, 1])
    results["output[0, 2]"].append(out_dtype[0, 2])
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    print(model)

    qcfg["nbits_a"] = 4
    qcfg["nbits_w"] = 4
    model = ResidualMLP(HIDDEN_DIM)
    qmodel_prep(model, x, qcfg)
    with torch.no_grad():
        out_dtype = model(x)
    results["dtype"].append("fmsmo_int4")
    results["output[0, 0]"].append(out_dtype[0, 0])
    results["output[0, 1]"].append(out_dtype[0, 1])
    results["output[0, 2]"].append(out_dtype[0, 2])
    results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
    print(model)

    # --- Test 2. now change mapping to MX
    # NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode;
    # qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically

    for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
        qcfg["qw_mode"] = f"mx_{dtype_to_test}"
        qcfg["qa_mode"] = f"mx_{dtype_to_test}"
        model = ResidualMLP(HIDDEN_DIM)  # fresh model
        qmodel_prep(model, x, qcfg)
        with torch.no_grad():
            out_dtype = model(x)
        results["dtype"].append(f"mx{dtype_to_test}")
        results["output[0, 0]"].append(out_dtype[0, 0])
        results["output[0, 1]"].append(out_dtype[0, 1])
        results["output[0, 2]"].append(out_dtype[0, 2])
        results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
        print(model)

    print(tabulate(results, headers="keys", tablefmt="pipe", floatfmt=".4f"))

    print("DONE!")
```

fms_mo/dq.py

Lines changed: 28 additions & 27 deletions
```diff
@@ -159,21 +159,21 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     config_quantize_smooth_layers(qcfg)

     if any(x != 32 for x in attn_bits):
-        logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
+        logger.info("Quantize attention bmms or kvcache, will use dynamo for prep")
         use_layer_name_pattern_matching = False
         qcfg["qlayer_name_pattern"] = []
         assert (
             qcfg["qlayer_name_pattern"] == []
         ), "ensure nothing in qlayer_name_pattern when use dynamo"
         use_dynamo = True
     else:
-        logger.info("Do not quantize attention bmms")
+        logger.info("Attention bmms will not be quantized.")
         use_layer_name_pattern_matching = True
         use_dynamo = False

     qcfg["seq_len"] = block_size
     qcfg["model"] = model_args.model_name_or_path
-    qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0
+    qcfg["smoothq"] = qcfg.get("smoothq_alpha", -1) >= 0 and "mx_specs" not in qcfg
     qcfg["plotsvg"] = False

     calibration_dataset = load_from_disk(data_args.training_data_path)
@@ -187,31 +187,32 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )

     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
-    if qcfg.get("act_scale_path", None):
-        # user provided a scale file (or a dir)
-        scale_file_or_dir = Path(qcfg["act_scale_path"])
-        if scale_file_or_dir.is_dir():
-            scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
-        elif scale_file_or_dir.is_file():
-            scale_file = scale_file_or_dir
-
-    if not scale_file.parent.exists():
-        scale_file.parent.mkdir(exist_ok=False)
-
-    if scale_file.exists():
-        act_scales = torch.load(
-            scale_file,
-            map_location=getattr(model, "device", dev),
-            weights_only=True,
-        )
-    else:
-        logger.info("Generate activation scales")
-        if qcfg["large_model"]:
-            act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
+    if qcfg["smoothq"]:
+        scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
+        if qcfg.get("act_scale_path", None):
+            # user provided a scale file (or a dir)
+            scale_file_or_dir = Path(qcfg["act_scale_path"])
+            if scale_file_or_dir.is_dir():
+                scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
+            elif scale_file_or_dir.is_file():
+                scale_file = scale_file_or_dir
+
+        if not scale_file.parent.exists():
+            scale_file.parent.mkdir(exist_ok=False)
+
+        if scale_file.exists():
+            act_scales = torch.load(
+                scale_file, map_location=getattr(model, "device", dev)
+            )
         else:
-            act_scales = get_act_scales(model, dq_dataloader, qcfg)
-        torch.save(act_scales, scale_file)
+            logger.info("Generate activation scales")
+            if qcfg["large_model"]:
+                act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
+            else:
+                act_scales = get_act_scales(model, dq_dataloader, qcfg)
+            torch.save(act_scales, scale_file)
+
     qmodel_prep(
         model,
         dq_dataloader,
```

fms_mo/fx/dynamo_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1218,7 +1218,7 @@ def call_seq_hook(mod, *_args, **_kwargs):
     # b) qbmm creation and attaching to model
     if qcfg.get("QBmm"):  # see Note 4
         # Local
-        from fms_mo.modules import QBmm
+        QBmm = qcfg["mapping"]["matmul_or_bmm"]

         qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
             "which2patch_contextmanager"
```

fms_mo/fx/utils.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -172,11 +172,6 @@ def lower_qmodel_to_ext_kernels(
         QLinearExllamaV2,
     )

-    qclass_accepted = []
-    for map_dict in qcfg["mapping"].values():
-        qclass_accepted.append(map_dict["to"])
-        qclass_accepted.append(map_dict.get("otherwise", None))
-
     mod2swap = {
         n: m
         for n, m in mod.named_modules()
@@ -498,7 +493,6 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
             )
         ),
     )
-
     if show_details:
         logger_or_print(df_summary_weights.to_markdown())

```
