Skip to content

Commit 6e4431c

Browse files
committed
feat: Added qmodel_prep mx_specs hook for default config with MX settings
Signed-off-by: Brandon Groth <[email protected]>
1 parent dd17104 commit 6e4431c

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

fms_mo/prep.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from fms_mo.calib import qmodel_calib
2929
from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
3030
from fms_mo.quant.quantizers import Qbypass
31-
from fms_mo.utils.qconfig_utils import check_config, qconfig_save
31+
from fms_mo.utils.import_utils import available_packages
32+
from fms_mo.utils.qconfig_utils import check_config, qconfig_save, set_mx_specs
3233
from fms_mo.utils.utils import prepare_inputs
3334

3435
# import numpy as np # only used in experimental func
@@ -197,6 +198,17 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
197198
qa_mode = qcfg.get("qa_mode", "pact+")
198199
qw_mode = qcfg.get("qw_mode", "sawb+")
199200

201+
# Check if MX has been set outside of qconfig_init without mx_specs being created
202+
if (
203+
available_packages["mx"]
204+
and "mx_specs" not in qcfg
205+
and (
206+
(qcfg["qa_mode"].startswith("mx_") and qcfg["qw_mode"].startswith("mx_"))
207+
or any(key.startswith("mx_") for key in qcfg.keys())
208+
)
209+
):
210+
set_mx_specs(qcfg, use_mx=True)
211+
200212
# check if on "black list" (need to be exact match), can be skipped or quantized those
201213
# to slightly higher "default" precision, or use qspecial_layers to have fine control
202214
if curr_full_name in qcfg["qskip_layer_name"]:

0 commit comments

Comments
 (0)