
Commit 308e5e2

Merge pull request #3 from BrandonGroth/mx_impl_brandon
fix: Setting mx_specs outside qconfig_init
2 parents dd17104 + 5694a5a commit 308e5e2

File tree

6 files changed: +55 -9 lines changed


.spellcheck-en-custom.txt

Lines changed: 4 additions & 1 deletion
@@ -117,4 +117,7 @@ venv
 vllm
 xs
 zp
-
+microxcaling
+MX
+MXINT
+MXFP

examples/MX/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # `microscaling` Examples Using a Toy Model and Direct Quantization (DQ)
 Here, we provide two simple examples of using MX format in `fms-mo`.
-"MX format", such as `MXFP8`, is a different format compared to typical IEEE formats, e.g. PyTorch FP8s (`e4m3` or `e5m2`, see our other [FP8 example](../FP8_QUANT/README.md).) Mainly all the `mx` format are group-based where each member of the group is using the specified format, e.g. FP8 for MXFP8 while each group has a shared (usualy 8-bit) "scale". Group size could be as small as 32 or 16, depending on hardware design.
+"MX format", such as `MXFP8`, is a different format compared to typical IEEE formats, e.g. PyTorch FP8s (`e4m3` or `e5m2`, see our other [FP8 example](../FP8_QUANT/README.md).) Mainly all the `mx` format are group-based where each member of the group is using the specified format, e.g. FP8 for MXFP8 while each group has a shared (usually 8-bit) "scale". Group size could be as small as 32 or 16, depending on hardware design.
 > [!NOTE]
 It is important to keep in mind that `mx` is not natively supported by Hopper GPUs yet (some will be supported by Blackwell), which means the quantization configurations and corresponding behavior are simulated, i.e. no real "speed up" should be expected.

@@ -23,7 +23,7 @@ Expected output includes:

 ```

-The second example is the same as in the [DQ](../DQ_SQ/README.md) folder, except using [microscaling](https://arxiv.org/abs/2310.10537) format. We demonstrate the effect of MXINT8, MXFP8, MXFP6, MXFP4 for weights, activations, and/or KV-cache.
+The second example is the same as in the [DQ](../DQ_SQ/README.md) folder, except using [microxcaling](https://arxiv.org/abs/2310.10537) format. We demonstrate the effect of MXINT8, MXFP8, MXFP6, MXFP4 for weights, activations, and/or KV-cache.

 **1. Prepare Data** for calibration process by converting into its tokenized form. An example of tokenization using `LLAMA-3-8B`'s tokenizer is below.

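The group-plus-shared-scale idea described in this README can be made concrete with a small simulation. Below is a minimal PyTorch sketch of MXINT8-style quantization: groups of 32 elements share one power-of-two scale while each member is stored in the low-precision element format. The function name `fake_quant_mx_groups` and the exact scale derivation are illustrative assumptions, not fms-mo or microxcaling API.

```python
# Minimal MX-style group quantization sketch (simulation only).
# Group size 32 and the shared power-of-two scale follow the README's
# description; all details here are assumptions, not fms-mo API.
import torch

def fake_quant_mx_groups(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Quantize-dequantize x with one shared scale per group of elements."""
    groups = x.reshape(-1, group_size)
    # One shared power-of-two scale per group, derived from the group's
    # largest magnitude (this plays the role of the ~8-bit shared "scale").
    max_abs = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = torch.pow(2.0, torch.floor(torch.log2(max_abs)) - 6)
    # Every group member is rounded into the INT8 element format.
    q = torch.clamp(torch.round(groups / scale), -128, 127)
    return (q * scale).reshape(x.shape)

x = torch.randn(4, 64)
print((x - fake_quant_mx_groups(x)).abs().max())  # small round-off error
```

Since this is all simulated tensor math, it also illustrates why no real speed-up should be expected on hardware without native MX support.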

fms_mo/prep.py

Lines changed: 13 additions & 1 deletion
@@ -28,7 +28,8 @@
 from fms_mo.calib import qmodel_calib
 from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
 from fms_mo.quant.quantizers import Qbypass
-from fms_mo.utils.qconfig_utils import check_config, qconfig_save
+from fms_mo.utils.import_utils import available_packages
+from fms_mo.utils.qconfig_utils import check_config, qconfig_save, set_mx_specs
 from fms_mo.utils.utils import prepare_inputs

 # import numpy as np # only used in experimental func
@@ -197,6 +198,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     qa_mode = qcfg.get("qa_mode", "pact+")
     qw_mode = qcfg.get("qw_mode", "sawb+")

+    # Check if MX has been set outside of qconfig_init without mx_specs being created
+    if (
+        available_packages["mx"]
+        and "mx_specs" not in qcfg
+        and (
+            (qcfg["qa_mode"].startswith("mx_") and qcfg["qw_mode"].startswith("mx_"))
+            or any(key.startswith("mx_") for key in qcfg.keys())
+        )
+    ):
+        set_mx_specs(qcfg, use_mx=True)
+
     # check if on "black list" (need to be exact match), can be skipped or quantized those
     # to slightly higher "default" precision, or use qspecial_layers to have fine control
     if curr_full_name in qcfg["qskip_layer_name"]:
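In user-facing terms, the guard added above covers configs where MX modes are assigned after `qconfig_init` returns. A minimal sketch of that flow, mirroring the test added in this commit (`qconfig_init` and `qmodel_prep` are the entry points the test exercises; the `from fms_mo import ...` path is an assumption):

```python
# Sketch of the scenario the new guard in make_quant_module handles.
# The flow mirrors this commit's test; import path is an assumption.
import torch
from fms_mo import qconfig_init, qmodel_prep  # assumed import path

qcfg = qconfig_init()              # default config, no MX settings yet
qcfg["qa_mode"] = "mx_fp8_e5m2"    # MX modes set *outside* qconfig_init
qcfg["qw_mode"] = "mx_fp8_e5m2"
assert "mx_specs" not in qcfg      # mx_specs was never created

model = torch.nn.Linear(128, 128)
qmodel_prep(model, torch.randn(16, 128), qcfg, use_dynamo=True)

assert "mx_specs" in qcfg  # the guard called set_mx_specs(qcfg, use_mx=True)
```

Note the guard only fires when the `mx` package is importable (`available_packages["mx"]`) and `mx_specs` is absent, so configs that already enabled MX inside `qconfig_init` are untouched.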

fms_mo/utils/qconfig_utils.py

Lines changed: 8 additions & 4 deletions
@@ -57,7 +57,7 @@ def config_defaults():
         ("bmm1_qm1_mode", "pact"),
         ("bmm1_qm2_mode", "pact"),
         ("bmm2_qm1_mode", "pact"),
-        ("bmm1_qm2_mode", "pact"),
+        ("bmm2_qm2_mode", "pact"),
         # mode_calib vars
         ("qa_mode_calib", "percentile"),
         ("qw_mode_calib", "percentile"),
@@ -1193,10 +1193,14 @@ def check_config(config, model_dtype=None):
     # 1. can use .func pointer to find the original class
     # 2. QBmm is optional, could be partial(QBmmMX,) or QBmm
     if mapping is not None:
-        if not mapping[nn.Linear].func is QLinearMX:
+        if mapping[nn.Linear].func is not QLinearMX:
            raise ValueError("MX mapping for nn.Linear is not QLinearMX")

         qbmm_map = mapping["matmul_or_bmm"]
-        if not (qbmm_map is QBmm or getattr(qbmm_map, "func", None) is QBmmMX):
-            raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
+        if bmm_mode_consistency > 0:
+            if getattr(qbmm_map, "func", None) is not QBmmMX:
+                raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
+        else:
+            if qbmm_map is not QBmm:
+                raise ValueError("Mapping for matmul_or_bmm is not QBmm")
     # End mx_specs checks
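The one-character `config_defaults` fix is easy to underrate: the defaults are written as key/value pairs, so a duplicated `bmm1_qm2_mode` key silently collapses and `bmm2_qm2_mode` never receives a default. A minimal illustration, assuming the pairs are ultimately folded into a dict:

```python
# Duplicate keys collapse when pairs become a dict, so the misspelled
# tuple left "bmm2_qm2_mode" without any default (hypothetical fold).
pairs_before_fix = [
    ("bmm1_qm1_mode", "pact"),
    ("bmm1_qm2_mode", "pact"),
    ("bmm2_qm1_mode", "pact"),
    ("bmm1_qm2_mode", "pact"),  # typo: repeats bmm1_qm2_mode
]
defaults = dict(pairs_before_fix)
print("bmm2_qm2_mode" in defaults)  # False: the default was missing
print(len(defaults))                # 3, not 4
```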

patches/README.md

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ To make a git diff patch file, first make your desired changes to the repository
 ```
 git diff > <package>.patch
 ```
-Packages may include files that differ by whitespaces even if you didn't change them.
+Packages may include files that differ by white spaces even if you didn't change them.
 To address this, add `--ignore-all-spaces` to the `git diff` command.

 To test the patch file, copy the `<package>.patch` file to `fms-model-optimizer/patches`.

tests/models/test_mx.py

Lines changed: 27 additions & 0 deletions
@@ -134,3 +134,30 @@ def test_residualMLP(
             assert module.mx_specs["a_elem_format"] == mx_format

     assert found_qmodule_mx
+
+
+@pytest.mark.skipif(
+    not available_packages["mx"],
+    reason="Skipping mx_specs error test; No package found",
+)
+def test_mx_specs_after_qconfig_init(
+    model_residualMLP: torch.nn.Module,
+    input_residualMLP: torch.FloatTensor,
+    config_fp32: dict,
+):
+    """
+    Test if a default config w/ MX qmodes trigger setting mx_specs inside qmodel_prep
+
+    Args:
+        model_residualMLP (torch.nn.Module): Single fp32 model.
+        input_residualMLP (torch.FloatTensor): Random 16x128 tensor.
+        config_fp32 (dict): Config w/ fp32 settings.
+    """
+    config_fp32["qa_mode"] = "mx_fp8_e5m2"
+    config_fp32["qw_mode"] = "mx_fp8_e5m2"
+
+    assert "mx_specs" not in config_fp32
+
+    qmodel_prep(model_residualMLP, input_residualMLP, config_fp32, use_dynamo=True)
+
+    assert "mx_specs" in config_fp32
