# `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
The microscaling ("MX") formats, such as `MXFP8`, are numeric formats different from the commonly used FP8 formats. For example, PyTorch provides two FP8 formats: 1 sign bit, 4 exponent bits, and 3 mantissa bits (denoted `e4m3`), or 1 sign bit, 5 exponent bits, and 2 mantissa bits (`e5m2`); see our other [FP8 example](../FP8_QUANT/README.md) for more details. In contrast, all `mx` formats are group-based data structures in which each member of a group uses the specified element format, e.g. FP8 for `MXFP8`, while each group carries a shared (usually 8-bit) "scale". The group size can be as small as 32 or 16, depending on the hardware design, so each MXFP8 number effectively requires 8.25 bits (with group size 32) rather than 8 bits. More details about microscaling can be found in [this OCP document](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).

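To make the shared-scale idea concrete, here is a minimal, self-contained sketch of MX-style fake quantization. It is purely illustrative and is not the `microxcaling` package API or the exact `fms-mo` implementation: the power-of-two scale and the symmetric 8-bit element grid are simplifying assumptions standing in for the real element formats.

```python
# Illustrative MX-style fake quantization: each group of 32 values shares one
# power-of-two scale, and each element is then rounded to a low-precision grid
# (a symmetric 8-bit integer grid here, standing in for the MX element format).
import torch

def fake_quant_mx_like(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Simulated quantize/dequantize along the last dimension (illustration only)."""
    orig_shape = x.shape
    groups = x.reshape(-1, group_size)
    # Shared scale per group, restricted to a power of two (as MX scales are).
    max_val = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = 2.0 ** torch.floor(torch.log2(max_val / 127.0))
    q = torch.clamp(torch.round(groups / scale), -127, 127)
    # Dequantize so the rest of the (simulated) model keeps running in FP32.
    return (q * scale).reshape(orig_shape)

x = torch.randn(4, 64)  # 4 x 64 = 256 elements -> 8 groups of 32
print((x - fake_quant_mx_like(x)).abs().max())  # small, group-dependent rounding error
```

In the real formats, the element grid would be one of the MX element types (FP8, FP6, FP4, or INT8), and the scale width and group size follow the OCP spec linked above.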
Here, we provide two simple examples of using the MX format in `fms-mo`.

> [!NOTE]
> Keep in mind that `mx` is not yet natively supported by Hopper GPUs (some formats will be supported by Blackwell), which means the quantization configurations and the corresponding behavior are simulated. Hence, no real "speed up" should be expected.

## Requirements
- [FMS Model Optimizer requirements](../../README.md#requirements)
- Microsoft's `microxcaling` Python package, available [here](https://github.com/microsoft/microxcaling.git).
> [!TIP]
> `FMS-Model-Optimizer` and `microxcaling` have clashing dependency requirements for `PyTorch` packages. We have created a patching solution to resolve this; run the following on the command line:
```bash
python3 ../install_patches.py
```
This patching script will either download the repo for you or look for an already installed copy in `$HOME` or the current working directory, then apply the patch.
For more information, see `patches/README.md`.

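After patching, a quick sanity check is to import the package. This assumes the patched `microxcaling` install exposes the top-level module `mx` (an assumption based on the repository layout; adjust to your install):

```python
# Sanity check only; the module name `mx` is an assumption based on the repo layout.
import mx

print("microxcaling loaded from:", mx.__file__)
```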
## QuickStart

### Example 1
The first example is based on a toy model with only a few `Linear` layers, in which one `Linear` layer is quantized with the MX versions of `int8`, `int4`, `fp8`, and `fp4`. The example can be run as follows:

```bash
python simple_mx_example.py
```

A comparison between the different formats, including the first three elements of the output tensor and the L2 norm of the difference from the FP32 reference, is shown below.

| dtype      | output[0, 0] | output[0, 1] | output[0, 2] | \|\|ref - out_dtype\|\|<sub>2</sub> |
|:-----------|-------------:|-------------:|-------------:|------------------------------------:|
| fp32       | -1.0491      | 0.5312       | -1.6387      | 0.0000 |
| fmsmo_int8 | -1.0577      | 0.5346       | -1.6508      | 0.4937 |
| fmsmo_int4 | -0.5885      | 0.5831       | -1.7976      | 8.2927 |
| mxint8     | -0.6444      | 0.6828       | -1.8626      | 8.3305 |
| mxint4     | -0.9089      | 0.6141       | -1.7630      | 8.0692 |
| mxfp8_e4m3 | -0.8031      | 0.7262       | -1.9581      | 7.8554 |
| mxfp8_e5m2 | -0.8471      | 0.7319       | -1.7458      | 8.1838 |
| mxfp4_e2m1 | -0.7506      | 0.6123       | -1.9311      | 7.9936 |

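For reference, the last column of the table is simply the L2 norm of the difference between each quantized output and the FP32 reference. The tensors below are stand-ins, not the toy model's actual outputs:

```python
# Illustrative only: `ref` and `out_dtype` stand in for the FP32 and quantized
# outputs produced by simple_mx_example.py; the metric is the L2 (Frobenius)
# norm of their element-wise difference.
import torch

ref = torch.randn(8, 16)                        # FP32 reference output (stand-in)
out_dtype = ref + 0.01 * torch.randn_like(ref)  # quantized-model output (stand-in)
print(torch.linalg.norm(ref - out_dtype))       # ||ref - out_dtype||_2
```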
### Example 2
The second example is the same as the [DQ example](../DQ_SQ/README.md), except that it uses the [microscaling](https://arxiv.org/abs/2310.10537) format. We only demonstrate `mxfp8` and `mxfp4` here, but `MXINT8`, `MXFP8`, `MXFP6`, and `MXFP4` are all available for weights, activations, and/or the KV-cache.

**1. Prepare Data** for the calibration process by converting it into tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is shown below.

```python
from transformers import AutoTokenizer
from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
num_samples = 128
seq_len = 2048
get_tokenized_data("wiki", num_samples, seq_len, tokenizer, path_to_save='data')
```
> [!NOTE]
> - Users should provide a tokenized data file based on their needs. This is just one example to demonstrate the data format `fms_mo` expects.
> - Tokenized data will be saved in `<path_to_save>_train` and `<path_to_save>_test`.
> - If you have trouble downloading the Llama family of models from Hugging Face ([Llama models require access](https://www.llama.com/docs/getting-the-models/hugging-face/)), you can use `ibm-granite/granite-8b-code` instead.

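If you want to see what "tokenized form" means for a single sample, the snippet below calls the tokenizer directly. It is purely illustrative and independent of how `get_tokenized_data` stores its output (the sentence and the printed slice are arbitrary choices):

```python
# Illustration only: tokenize one sentence and inspect the token ids the model
# would consume during calibration. Uses the same (gated) tokenizer as above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", use_fast=True)
sample = tokenizer("Microscaling formats share one scale per group of elements.")
print(sample["input_ids"][:8])                 # first few token ids
print(tokenizer.decode(sample["input_ids"]))   # round-trip back to text
```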
**2. Apply DQ** by providing specific hyper-parameters such as `quant_method`, the weight quantizer (`qw_mode`), and the activation quantizer (`qa_mode`). An example using `Meta-Llama-3-8B` and the tokenized training and test data is provided below.
```bash
python -m fms_mo.run_quant \
    --model_name_or_path "meta-llama/Meta-Llama-3-8B" \
    --training_data_path data_train \
    --test_data_path data_test \
    --torch_dtype "float16" \
    --quant_method dq \
    --nbits_w 8 \
    --nbits_a 8 \
    --nbits_kvcache 32 \
    --qa_mode "mx_fp8_e4m3" \
    --qw_mode "mx_fp8_e4m3" \
    --output_dir "dq_test" \
    --eval_ppl
```
> [!NOTE]
> To use an MX format, simply set the `qa_mode` and `qw_mode` arguments to `mx_<dtype supported by mx package>`, e.g. `mx_fp8_e4m3` as in the example above. The corresponding `QLinearMX` wrappers will be used in place of `QLinear` as in the other examples.

**3. Compare the Perplexity score.** For convenience, the code prints the perplexity at the end of the run (controlled by the `eval_ppl` flag), so no additional steps are needed, provided the logging level is set to `INFO` in the terminal. You can also check the output in the log file `./fms_mo.log`.

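As a reminder of what the reported number means, perplexity is the exponential of the average per-token negative log-likelihood on the test set. The sketch below shows only this definition, not the actual evaluation code inside `fms-mo`:

```python
# Definition only (not the fms-mo evaluation code): perplexity = exp(mean NLL per token).
import math
import torch

def perplexity(nll_per_token: torch.Tensor) -> float:
    """nll_per_token: 1-D tensor of per-token negative log-likelihoods."""
    return math.exp(nll_per_token.mean().item())

print(perplexity(torch.tensor([1.9, 1.7, 1.8, 1.9])))  # ~6.2 for this toy input
```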
## Example Test Results
The perplexity of the INT8, FP8, and MX quantized models on the `wikitext` dataset is shown below:

| Model     | Type   | QA             | QW             | DQ  | SQ     | Perplexity |
|:---------:|:------:|:--------------:|:--------------:|:---:|:------:|:----------:|
|`Llama3-8b`| INT8   | maxpertoken    | maxperCh       | yes | yes    | 6.22       |
|           | FP8    | fp8_e4m3_scale | fp8_e4m3_scale | yes | yes    | 6.19       |
|           | **MX** | mx_fp8_e4m3    | mx_fp8_e4m3    | yes | **no** | 6.23       |
|           | **MX** | mx_fp4_e2m1    | mx_fp4_e2m1    | yes | **no** | 8.22       |

> [!NOTE]
> SmoothQuant is disabled when `mx` is being used. See `dq.py` for more details.