@@ -57,7 +57,7 @@ def config_defaults():
5757 ("bmm1_qm1_mode" , "pact" ),
5858 ("bmm1_qm2_mode" , "pact" ),
5959 ("bmm2_qm1_mode" , "pact" ),
60- ("bmm1_qm2_mode " , "pact" ),
60+ ("bmm2_qm2_mode " , "pact" ),
6161 # mode_calib vars
6262 ("qa_mode_calib" , "percentile" ),
6363 ("qw_mode_calib" , "percentile" ),
@@ -1193,10 +1193,14 @@ def check_config(config, model_dtype=None):
11931193 # 1. can use .func pointer to find the original class
11941194 # 2. QBmm is optional, could be partial(QBmmMX,) or QBmm
11951195 if mapping is not None :
1196- if not mapping [nn .Linear ].func is QLinearMX :
1196+ if mapping [nn .Linear ].func is not QLinearMX :
11971197 raise ValueError ("MX mapping for nn.Linear is not QLinearMX" )
11981198
11991199 qbmm_map = mapping ["matmul_or_bmm" ]
1200- if not (qbmm_map is QBmm or getattr (qbmm_map , "func" , None ) is QBmmMX ):
1201- raise ValueError ("MX mapping for matmul_or_bmm is not QBmmMX" )
1200+ if bmm_mode_consistency > 0 :
1201+ if getattr (qbmm_map , "func" , None ) is not QBmmMX :
1202+ raise ValueError ("MX mapping for matmul_or_bmm is not QBmmMX" )
1203+ else :
1204+ if qbmm_map is not QBmm :
1205+ raise ValueError ("Mapping for matmul_or_bmm is not QBmm" )
12021206 # End mx_specs checks
0 commit comments