Skip to content

Commit dd17104

Browse files
fix bug in config check as QBmmMX is optional
Signed-off-by: cliu-us <[email protected]>
1 parent 3446413 commit dd17104

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

fms_mo/utils/qconfig_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,11 +1189,14 @@ def check_config(config, model_dtype=None):
11891189

11901190
mapping = config.get("mapping", None)
11911191

1192-
# partial was used to init this mapping --> use .func pointer
1192+
# partial was used to wrap QLinearMX, will be an instance of partial
1193+
# 1. can use .func pointer to find the original class
1194+
# 2. QBmm is optional, could be partial(QBmmMX,) or QBmm
11931195
if mapping is not None:
11941196
if not mapping[nn.Linear].func is QLinearMX:
11951197
raise ValueError("MX mapping for nn.Linear is not QLinearMX")
11961198

1197-
if mapping["matmul_or_bmm"].func is QBmmMX:
1199+
qbmm_map = mapping["matmul_or_bmm"]
1200+
if not (qbmm_map is QBmm or getattr(qbmm_map, "func", None) is QBmmMX):
11981201
raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
11991202
# End mx_specs checks

0 commit comments

Comments
 (0)