File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -1189,11 +1189,14 @@ def check_config(config, model_dtype=None):
11891189
11901190 mapping = config .get ("mapping" , None )
11911191
1192- # partial was used to init this mapping --> use .func pointer
1192+ # partial was used to wrap QLinearMX, will be an instance of partial
1193+ # 1. can use .func pointer to find the original class
1194+ # 2. QBmm is optional, could be partial(QBmmMX,) or QBmm
11931195 if mapping is not None :
11941196 if not mapping [nn .Linear ].func is QLinearMX :
11951197 raise ValueError ("MX mapping for nn.Linear is not QLinearMX" )
11961198
1197- if mapping ["matmul_or_bmm" ].func is QBmmMX :
1199+ qbmm_map = mapping ["matmul_or_bmm" ]
1200+ if not (qbmm_map is QBmm or getattr (qbmm_map , "func" , None ) is QBmmMX ):
11981201 raise ValueError ("MX mapping for matmul_or_bmm is not QBmmMX" )
11991202 # End mx_specs checks
You can’t perform that action at this time.
0 commit comments