|
28 | 28 | from fms_mo.calib import qmodel_calib |
29 | 29 | from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules |
30 | 30 | from fms_mo.quant.quantizers import Qbypass |
31 | | -from fms_mo.utils.qconfig_utils import check_config, qconfig_save |
| 31 | +from fms_mo.utils.import_utils import available_packages |
| 32 | +from fms_mo.utils.qconfig_utils import check_config, qconfig_save, set_mx_specs |
32 | 33 | from fms_mo.utils.utils import prepare_inputs |
33 | 34 |
|
34 | 35 | # import numpy as np # only used in experimental func |
@@ -197,6 +198,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False): |
197 | 198 | qa_mode = qcfg.get("qa_mode", "pact+") |
198 | 199 | qw_mode = qcfg.get("qw_mode", "sawb+") |
199 | 200 |
|
| 201 | + # Check if MX has been set outside of qconfig_init without mx_specs being created |
| 202 | + if ( |
| 203 | + available_packages["mx"] |
| 204 | + and "mx_specs" not in qcfg |
| 205 | + and ( |
| 206 | + (qcfg["qa_mode"].startswith("mx_") and qcfg["qw_mode"].startswith("mx_")) |
| 207 | + or any(key.startswith("mx_") for key in qcfg.keys()) |
| 208 | + ) |
| 209 | + ): |
| 210 | + set_mx_specs(qcfg, use_mx=True) |
| 211 | + |
200 | 212 | # check if on "black list" (need to be exact match), can be skipped or quantized those |
201 | 213 | # to slightly higher "default" precision, or use qspecial_layers to have fine control |
202 | 214 | if curr_full_name in qcfg["qskip_layer_name"]: |
|
0 commit comments