@@ -218,8 +218,11 @@ def _convert_tensor_parallel(self, lora_state_dict):
                 lora_name_action_mappings[v] = name_action_mappings[k]
 
         for name, action in lora_name_action_mappings.items():
-            tensor = lora_state_dict.pop(name)
-            lora_state_dict[name] = action(tensor)
+            if name in lora_state_dict:
+                tensor = lora_state_dict.pop(name)
+                lora_state_dict[name] = action(tensor)
+            else:
+                logger.warning(f"{name} not found in lora_state_dict!")
         return lora_state_dict
 
     def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
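Why the guard matters, as a standalone sketch (the `state_dict` and action mapping here are made-up stand-ins, not PaddleNLP objects): when a converted name has no matching key in the checkpoint, the old unconditional `pop` raised `KeyError`; the new branch logs a warning and keeps going.

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Hypothetical checkpoint contents and conversion actions, for illustration only.
state_dict = {"layer.lora_A": [[1.0]], "layer.lora_B": [[2.0]]}
name_action_mappings = {
    "layer.lora_A": lambda t: t,       # identity stand-in for a split/merge action
    "layer.missing_key": lambda t: t,  # key absent from this checkpoint
}

for name, action in name_action_mappings.items():
    if name in state_dict:
        tensor = state_dict.pop(name)
        state_dict[name] = action(tensor)
    else:
        # Before this change, the unconditional pop raised KeyError here.
        logger.warning(f"{name} not found in lora_state_dict!")
```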
@@ -326,9 +329,10 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
             self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
 
             # for lora qat
-            self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
-            self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            if self.lora_config.do_qat:
+                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
+                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif isinstance(module, RowParallelLinear):
             # recover the original output_features
             lora_module = RowParallelLoRALinear(
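For orientation between these two hunks: `lora_B` gets `is_column=True` under `ColumnParallelLinear` because column parallelism shards the output dimension, while the next hunk marks `lora_A` with `is_column=False` because row parallelism shards the input dimension. A rough NumPy sketch of that partitioning (shapes and `world_size` are illustrative, not the framework's actual split actions):

```python
import numpy as np

world_size = 2  # hypothetical tensor-parallel degree
in_features, out_features, r = 8, 8, 4

lora_A = np.zeros((in_features, r), dtype=np.float32)
lora_B = np.zeros((r, out_features), dtype=np.float32)

# ColumnParallelLinear shards the output dim, so lora_B is split by
# columns (is_column=True): each rank keeps a slice of axis 1.
b_shards = np.split(lora_B, world_size, axis=1)

# RowParallelLinear shards the input dim, so lora_A is split by
# rows (is_column=False): each rank keeps a slice of axis 0.
a_shards = np.split(lora_A, world_size, axis=0)

print(b_shards[0].shape)  # (4, 4): full r, half of out_features
print(a_shards[0].shape)  # (4, 4): half of in_features, full r
```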
@@ -345,9 +349,10 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
             self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
 
             # for lora qat
-            self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
-            self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
+            if self.lora_config.do_qat:
+                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
+                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
         elif QuantizationLinear is not None and isinstance(module, QuantizationLinear):
             lora_module = QuantizationLoRALinear(
                 in_features=module.in_features,
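The `do_qat` gate is small but load-bearing: without it, the split-mapping table advertises `*_quanter._scale` entries that a plain (non-QAT) LoRA checkpoint never contains, and the tensor-parallel converter above would then warn about them for every layer. A minimal sketch of flag-gated registration, with hypothetical names (`LoRAConfigSketch`, `split_mappings`, `register_mappings`):

```python
from dataclasses import dataclass


@dataclass
class LoRAConfigSketch:
    # do_qat mirrors the flag read in the diff; the class itself is illustrative.
    do_qat: bool = False


split_mappings = {}  # name -> is_column, standing in for add_lora_split_mapping


def register_mappings(module_name: str, config: LoRAConfigSketch) -> None:
    split_mappings[module_name + ".lora_B"] = True
    if config.do_qat:
        # Quanter scales only exist in QAT checkpoints, so register them
        # only when QAT is enabled; otherwise the TP converter would warn
        # about missing keys on every layer.
        split_mappings[module_name + ".weight_quanter._scale"] = True
        split_mappings[module_name + ".activation_quanter._scale"] = False


register_mappings("model.layers.0.q_proj", LoRAConfigSketch(do_qat=False))
assert "model.layers.0.q_proj.weight_quanter._scale" not in split_mappings
```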