
Commit ab060ca

update MX README.md, fix simple_mx_example.py, and fix a few minor bugs
Signed-off-by: cliu-us <[email protected]>
1 parent e5eb0ed commit ab060ca

File tree (4 files changed: +66 −146 lines)

- examples/MX/README.md
- examples/MX/simple_mx_example.py
- fms_mo/utils/qconfig_utils.py
- fms_mo/utils/utils.py

examples/MX/README.md

Lines changed: 25 additions & 122 deletions
@@ -1,8 +1,10 @@
 # `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
+Microscaling ("MX") formats, such as `MXFP8`, are numeric formats that differ from the commonly used FP8 formats. For example, PyTorch provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. In contrast, all `mx` formats are group-based data structures: each member of a group uses the specified element format (e.g. FP8 for MXFP8), while the group shares one (usually 8-bit) "scale". The group size can be as small as 32 or 16, depending on the hardware design. One may therefore consider each MXFP8 number to require 8.25 bits (for a group size of 32) rather than 8 bits.
+
 Here, we provide two simple examples of using MX format in `fms-mo`.
-"MX format", such as `MXFP8`, is a different format compared to typical IEEE formats, e.g. PyTorch FP8s (`e4m3` or `e5m2`, see our other [FP8 example](../FP8_QUANT/README.md).) Mainly all the `mx` format are group-based where each member of the group is using the specified format, e.g. FP8 for MXFP8 while each group has a shared (usually 8-bit) "scale". Group size could be as small as 32 or 16, depending on hardware design.
+
 > [!NOTE]
-It is important to keep in mind that `mx` is not natively supported by Hopper GPUs yet (some will be supported by Blackwell), which means the quantization configurations and corresponding behavior are simulated, i.e. no real "speed up" should be expected.
+It is important to keep in mind that `mx` is not natively supported by Hopper GPUs yet (some formats will be supported by Blackwell), which means the quantization configurations and the corresponding behavior are simulated. Hence, no real "speed up" should be expected.
 
 
 ## Requirements
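To make the shared-scale idea in the new paragraph concrete, here is a minimal, hypothetical sketch in plain PyTorch (not the `microxcaling` package and not the quantizers `fms-mo` actually uses): every group of 32 values is rounded to int8 against a single power-of-two scale, which is also where the "8.25 bits per value" accounting comes from (32 × 8 element bits plus one 8-bit scale per group).

```python
# Minimal, illustrative sketch of MX-style group quantization (not fms-mo code).
import torch

def mx_int8_fake_quant(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Fake-quantize x to int8 elements that share one power-of-two scale per group."""
    g = x.reshape(-1, group_size)
    # shared exponent per group, taken from the largest magnitude in the group
    shared_exp = torch.floor(torch.log2(g.abs().amax(dim=-1, keepdim=True) + 1e-30))
    scale = 2.0 ** (shared_exp - 6)  # leave headroom so the rounded values fit int8
    q = torch.clamp(torch.round(g / scale), -128, 127)
    return (q * scale).reshape(x.shape)  # dequantized ("simulated") values

x = torch.randn(16, 128)
print(torch.norm(x - mx_int8_fake_quant(x)))  # small group-wise quantization error
```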
@@ -18,16 +20,28 @@ For more information, see `patches/README.md`.
 
 ## QuickStart
 
+### Example 1
 The first example is based on a toy model with only a few Linear layers, in which a single Linear layer will be quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can simply be run as follows:
 
 ```bash
 >>> python simple_mx_example.py
 ```
-Expected output includes:
-```bash
 
-```
+Expected output:
+
+| dtype | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|<sub>2</sub> |
+|:-----------|---------------:|---------------:|---------------:|------------------------:|
+| fp32 | -1.0491 | 0.5312 | -1.6387 | 0.0000 |
+| fmsmo_int8 | -1.0577 | 0.5346 | -1.6508 | 0.4937 |
+| fmsmo_int4 | -0.5885 | 0.5831 | -1.7976 | 8.2927 |
+| mxint8 | -0.6444 | 0.6828 | -1.8626 | 8.3305 |
+| mxint4 | -0.9089 | 0.6141 | -1.7630 | 8.0692 |
+| mxfp8_e4m3 | -0.8031 | 0.7262 | -1.9581 | 7.8554 |
+| mxfp8_e5m2 | -0.8471 | 0.7319 | -1.7458 | 8.1838 |
+| mxfp4_e2m1 | -0.7506 | 0.6123 | -1.9311 | 7.9936 |
+
 
+### Example 2
 The second example is the same as the one in the [DQ](../DQ_SQ/README.md) folder, except that it uses the [microxcaling](https://arxiv.org/abs/2310.10537) format. We demonstrate the effect of MXINT8, MXFP8, MXFP6, and MXFP4 for weights, activations, and/or KV-cache.
 
 **1. Prepare Data** for the calibration process by converting it into tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below.
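The README's own tokenization example sits just outside this hunk and is not shown here. Purely as an illustration of what this step does, a hypothetical sketch using Hugging Face `transformers` and `datasets` (not necessarily the script the README uses) might look like this:

```python
# Hypothetical illustration of the data-prep step; see the README for the actual command.
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokenized = raw.map(
    lambda batch: tokenizer(batch["text"]),  # raw text -> input_ids / attention_mask
    batched=True,
    remove_columns=["text"],
)
tokenized.save_to_disk("wikitext_train_tokenized")  # example path, not the README's
```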
@@ -62,134 +76,23 @@ python -m fms_mo.run_quant \
     --output_dir "dq_test" \
     --eval_ppl
 ```
-> [!TIP]
+> [!NOTE]
 > To use the MX format, simply assign the `qa_mode` and `qw_mode` arguments a value of the form `mx_<dtype supported by the mx package>`, e.g. `mx_fp8_e4m3` as in the above example. The corresponding `QLinearMX` wrappers will then be used in place of the `QLinear` wrappers seen in the other examples.
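For readers who prefer the programmatic route, the same modes are set in `simple_mx_example.py` (modified later in this commit) by assigning them on the quantization config before `qmodel_prep`. A rough sketch, where the import path and the `model`/`sample_input` names are assumptions rather than taken from this diff:

```python
# Sketch based on simple_mx_example.py in this commit; exact imports/names may differ.
from fms_mo import qconfig_init, qmodel_prep  # import path assumed

qcfg = qconfig_init()
qcfg["qa_mode"] = "mx_fp8_e4m3"  # activation quantizer -> MXFP8 (e4m3)
qcfg["qw_mode"] = "mx_fp8_e4m3"  # weight quantizer -> MXFP8 (e4m3)
qmodel_prep(model, sample_input, qcfg)  # Linear layers get wrapped as QLinearMX
```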
 
 **3. Compare the Perplexity score.** For convenience, the code prints the perplexity (controlled by the `eval_ppl` flag) at the end of the run, so no additional steps are needed (provided the logging level is set to `INFO` in the terminal). You can also check the output in the log file `./fms_mo.log`.
 
-# *TO BE UPDATED BELOW THIS LINE*
-
 
 ## Example Test Results
 The perplexity of the INT8, FP8, and MX quantized models on the `wikitext` dataset is shown below:
 
 | Model |Type |QA |QW |DQ |SQ |Perplexity|
 |:---------:|:---:|:------------:|:------------:|:--:|:--:|:--------:|
-|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.21 |
+|`Llama3-8b`|INT8 |maxpertoken |maxperCh |yes |yes |6.22 |
 | |FP8 |fp8_e4m3_scale|fp8_e4m3_scale|yes |yes |6.19 |
+| |**MX**|mx_fp8_e4m3 |mx_fp8_e4m3 |yes |**no** |6.23 |
+| |**MX**|mx_fp4_e2m1 |mx_fp4_e2m1 |yes |**no** |8.22 |
 
-## Code Walk-through
-
-**1. KV caching**
-
-In large language models (LLMs), key/value pairs are frequently cached during token generation, a process known as KV caching, to prevent redundant computations due to the autoregressive nature of token generation. However, the size of the KV cache increases with both batch size and context length, which can slow down model inference due to the need to access a large amount of data in memory. Quantizing the KV cache effectively reduces this memory bandwidth limitation, improving inference speed. To study the quantization behavior of KV cache, we can simply set the `nbits_kvcache` argument to 8-bit, then the KV cache will be quantized together with weights and activations. In addition, the `bmm1_qm1_mode`, `bmm1_qm2_mode`, and `bmm2_qm2_mode` [arguments](../../fms_mo/training_args.py) must be set to the same quantizer mode as `qa_mode`. **NOTE**: `bmm2_qm1_mode` should be kept as `minmax`.
-
-The effect of setting the `nbits_kvcache` to 8 and its relevant code sections are:
-
-- Enables eager attention for the quantization of attention operations, including KV cache.
-```python
-# For attention or kv-cache quantization, need to use eager attention
-attn_bits = [fms_mo_args.nbits_bmm1, fms_mo_args.nbits_bmm2, fms_mo_args.nbits_kvcache]
-if any(attn_bits) != 32:
-    attn_implementation = "eager"
-else:
-    attn_implementation = None
-```
-- Enables Dynamo for quantized model preparation. We use PyTorch's Dynamo tracer to identify the bmm and KV cache inside the attention block.
-```python
-if any(x != 32 for x in attn_bits):
-    logger.info("Quantize attention bmms or kvcache, use dynamo for prep")
-    use_layer_name_pattern_matching = False
-    qcfg["qlayer_name_pattern"] = []
-    assert (
-        qcfg["qlayer_name_pattern"] == []
-    ), "ensure nothing in qlayer_name_pattern when use dynamo"
-    use_dynamo = True
-else:
-    logger.info("Do not quantize attention bmms")
-    use_layer_name_pattern_matching = True
-    use_dynamo = False
-```
-
-**2. Define quantization config** including quantizers and hyperparameters. Here we simply use the default [dq recipe](../../fms_mo/recipies/dq.json).
 
-```python
-qcfg = qconfig_init(recipe="dq",args=fms_mo_args)
-```
-
-**3. Obtain activation scales for SmoothQuant (SQ)**
-
-``` python
-# For loading or creating smoothquant scale.
-act_scale_directory = "./act_scales"
-if not os.path.exists(act_scale_directory):
-    os.makedirs(act_scale_directory)
-
-if qcfg["act_scale_path"] is not None:
-    act_scales = torch.load(qcfg["act_scale_path"], map_location="cpu")
-else:
-    logger.info("Generate activation scales")
-    if qcfg["large_model"]:
-        act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
-    else:
-        act_scales = get_act_scales(model, dq_dataloader, qcfg)
-    scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
-    torch.save(act_scales, scale_file)
-```
-
-**4. Prepare the quantized model and attach activation scales** to quantized modules
-
-```python
-qmodel_prep(
-    model,
-    dq_dataloader,
-    qcfg,
-    use_layer_name_pattern_matching=use_layer_name_pattern_matching,
-    use_dynamo=use_dynamo,
-    dev=dev,
-    save_fname='test'
-)
-
-dq_llm(model, act_scales, qcfg)
-```
-
-**5. Perform direct quantization** by calibrating quantizers (clip_vals)
-
-``` python
-if qcfg["qmodel_calibration_new"] > 0:
-    logger.info("Starting to calibrate activation clip_val")
-    if qcfg["large_model"]:
-        calibration_llm_1GPU(qcfg, model, calibration_dataset)
-    else:
-        model.to("cuda:0")
-        pbar = tqdm(
-            dq_dataloader,
-            desc=" calibration after applying smoothq scale and before inference",
-            total=qcfg["qmodel_calibration_new"],
-        )
-        for data_mb, _ in zip(pbar, range(qcfg["qmodel_calibration_new"])):
-            data_mb = prepare_input(model.device, data_mb)
-            with patch_torch_bmm(qcfg):
-                model(**data_mb)
-
-logger.info(f"Saving quantized model and tokenizer to {output_dir}")
-model.save_pretrained(output_dir, use_safetensors=True)
-tokenizer.save_pretrained(output_dir)
-```
+> [!NOTE]
+> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.
 
-**6. Check perplexity** (simple method to evaluate the model quality)
-
-``` python
-if fms_mo_args.eval_ppl:
-    logger.info(f"Model for evaluation: {model}")
-    if qcfg["large_model"]:
-        eval_llm_1GPU(qcfg, model, test_dataset)
-    else:
-        model.to(torch.device("cuda:0"))
-        n_samples = int(test_dataset.input_ids.shape[1] / block_size)
-        evaluator = Evaluator(test_dataset, "cuda", n_samples=n_samples)
-        ppl = evaluator.evaluate(model, block_size=block_size)
-        logger.info(f"Model perplexity: {ppl}")
-    logger.info("-" * 50)
-    logger.info("Finished evaluation")
-```

examples/MX/simple_mx_example.py

Lines changed: 32 additions & 19 deletions
@@ -54,38 +54,49 @@ def forward(self, inputs):
 HIDDEN_DIM = 128
 x = np.random.randn(16, HIDDEN_DIM)
 x = torch.tensor(x, dtype=torch.float32, device="cuda")
-results = {"dtype": [], "output[0, :5]": [], "||ref - out_dtype||_2": []}
+results = {
+    "dtype": [],
+    "output[0, 0]": [],
+    "output[0, 1]": [],
+    "output[0, 2]": [],
+    "||ref - out_dtype||_2": [],
+}
 
 # --- Test 0. Run MLP as is
-mlp = ResidualMLP(HIDDEN_DIM)
-# mlp.to("cuda")
+model = ResidualMLP(HIDDEN_DIM)
 with torch.no_grad():
-    out = mlp(x)
+    out = model(x)
 results["dtype"].append("fp32")
-results["output[0, :5]"].append(out[0, :5].tolist())
-results["||ref - out_dtype||_2"].append("-")
-print(mlp)
+results["output[0, 0]"].append(out[0, 0])
+results["output[0, 1]"].append(out[0, 1])
+results["output[0, 2]"].append(out[0, 2])
+results["||ref - out_dtype||_2"].append(0)
+print(model)
 
 # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
 qcfg = qconfig_init()
 qcfg["nbits_a"] = 8
 qcfg["nbits_w"] = 8
-model = qmodel_prep(mlp, x, qcfg)
+qmodel_prep(model, x, qcfg)
 with torch.no_grad():
     out_dtype = model(x)
-results["dtype"].append("fms_int8")
-results["output[0, :5]"].append(out_dtype[0, :5].tolist())
+results["dtype"].append("fmsmo_int8")
+results["output[0, 0]"].append(out_dtype[0, 0])
+results["output[0, 1]"].append(out_dtype[0, 1])
+results["output[0, 2]"].append(out_dtype[0, 2])
 results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
-# print(model)
+print(model)
 
 qcfg["nbits_a"] = 4
 qcfg["nbits_w"] = 4
-mlp = ResidualMLP(HIDDEN_DIM)
-model = qmodel_prep(mlp, x, qcfg)
+model = ResidualMLP(HIDDEN_DIM)
+qmodel_prep(model, x, qcfg)
 with torch.no_grad():
     out_dtype = model(x)
-results["dtype"].append("fms_int4")
-results["output[0, :5]"].append(out_dtype[0, :5].tolist())
+results["dtype"].append("fmsmo_int4")
+results["output[0, 0]"].append(out_dtype[0, 0])
+results["output[0, 1]"].append(out_dtype[0, 1])
+results["output[0, 2]"].append(out_dtype[0, 2])
 results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
 print(model)
 
@@ -96,15 +107,17 @@ def forward(self, inputs):
 for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
     qcfg["qw_mode"] = f"mx_{dtype_to_test}"
     qcfg["qa_mode"] = f"mx_{dtype_to_test}"
-    mlp = ResidualMLP(HIDDEN_DIM) # fresh model
-    model = qmodel_prep(mlp, x, qcfg)
+    model = ResidualMLP(HIDDEN_DIM) # fresh model
+    qmodel_prep(model, x, qcfg)
     with torch.no_grad():
         out_dtype = model(x)
     results["dtype"].append(f"mx{dtype_to_test}")
-    results["output[0, :5]"].append(out_dtype[0, :5].tolist())
+    results["output[0, 0]"].append(out_dtype[0, 0])
+    results["output[0, 1]"].append(out_dtype[0, 1])
+    results["output[0, 2]"].append(out_dtype[0, 2])
     results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
     print(model)
 
-print(tabulate(results, headers="keys"))
+print(tabulate(results, headers="keys", tablefmt="pipe", floatfmt=".4f"))
 
 print("DONE!")

fms_mo/utils/qconfig_utils.py

Lines changed: 8 additions & 5 deletions
@@ -342,7 +342,10 @@ def set_mx_specs(
 
     # Check args for any mx_specs vars
     use_mx_args = args is not None and any(
-        hasattr(args, key) for key, _ in fms_defaults.items()
+        hasattr(args, key)
+        for key, _ in fms_defaults.items()
+        if key != "block_size"
+        # some items are not unique to mx, add names here if needed
     )
 
     # Lastly, check for BMM consistency to enable QBmmMX
@@ -479,9 +482,9 @@ def get_mx_specs_defaults():
         "a_elem_format_bp_os": "fp8_e4m3",
         "shared_exp_method": "max",
         "scale_bits": 8,
-        "block_size": 32,
-        "bfloat": 16,
-        "fp": 16,
+        "block_size": 32,  # this item is not unique to mx
+        "bfloat": 16,  # bfloat and fp cannot be set at the same time
+        "fp": 0,
         "round": "nearest",
         "round_m": "nearest",
         "round_weight": "nearest",
@@ -1100,7 +1103,7 @@ def check_config(config, model_dtype=None):
     mx_spec_int_var_str_defaults = [
         ("scale_bits", 8),
         ("block_size", 32),
-        ("fp", 16),
+        # ("fp", 16), # can only set either fp or bfloat to non-zero
         ("bfloat", 16),
     ]
     mx_spec_int_var_values = {2, 4, 6, 8, 16, 32}
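The comments in these hunks encode a constraint on the MX backing format: only one of `fp` and `bfloat` may be non-zero, and the integer-valued specs must come from the allowed set above. A small hypothetical helper (not part of `fms_mo`) expressing that rule might look like:

```python
# Hypothetical sanity check mirroring the comments above; not part of fms_mo.
ALLOWED_BITS = {2, 4, 6, 8, 16, 32}

def check_mx_backing_format(mx_specs: dict) -> None:
    fp, bf = mx_specs.get("fp", 0), mx_specs.get("bfloat", 0)
    if fp and bf:
        raise ValueError("Set either 'fp' or 'bfloat' to non-zero, not both.")
    for name, val in (("fp", fp), ("bfloat", bf)):
        if val and val not in ALLOWED_BITS:
            raise ValueError(f"{name}={val} must be one of {sorted(ALLOWED_BITS)}")
```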

fms_mo/utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ def prepare_input(
     if isinstance(data, torch.Tensor):
         kwargs = {"device": device}
         return data.to(**kwargs)
+
     logger.warning(
         "data input to prepare_input must be Dict, "
         "Tuple, List or torch.Tensor and currently is",
