Skip to content

Commit 5694a5a

Browse files
committed
fix: Changed check_config check for MX mapping
Signed-off-by: Brandon Groth <[email protected]>
1 parent 7daae3b commit 5694a5a

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

fms_mo/utils/qconfig_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def config_defaults():
5757
("bmm1_qm1_mode", "pact"),
5858
("bmm1_qm2_mode", "pact"),
5959
("bmm2_qm1_mode", "pact"),
60-
("bmm1_qm2_mode", "pact"),
60+
("bmm2_qm2_mode", "pact"),
6161
# mode_calib vars
6262
("qa_mode_calib", "percentile"),
6363
("qw_mode_calib", "percentile"),
@@ -1193,10 +1193,14 @@ def check_config(config, model_dtype=None):
11931193
# 1. can use .func pointer to find the original class
11941194
# 2. QBmm is optional, could be partial(QBmmMX,) or QBmm
11951195
if mapping is not None:
1196-
if not mapping[nn.Linear].func is QLinearMX:
1196+
if mapping[nn.Linear].func is not QLinearMX:
11971197
raise ValueError("MX mapping for nn.Linear is not QLinearMX")
11981198

11991199
qbmm_map = mapping["matmul_or_bmm"]
1200-
if not (qbmm_map is QBmm or getattr(qbmm_map, "func", None) is QBmmMX):
1201-
raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
1200+
if bmm_mode_consistency > 0:
1201+
if getattr(qbmm_map, "func", None) is not QBmmMX:
1202+
raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
1203+
else:
1204+
if qbmm_map is not QBmm:
1205+
raise ValueError("Mapping for matmul_or_bmm is not QBmm")
12021206
# End mx_specs checks

0 commit comments

Comments
 (0)