@@ -34,15 +34,7 @@ def forward_with_lora(x, layer1=original_forward, layer2=lora):
3434
3535def load_lora (model , path ):
3636 state_dict = torch .load (path , map_location = model .device )
37-
38- # 兼容DDP训练保存的权重(带有module.前缀),去除前缀以匹配当前模型
39- new_state_dict = {}
40- for k , v in state_dict .items ():
41- if k .startswith ('module.' ):
42- new_state_dict [k [7 :]] = v
43- else :
44- new_state_dict [k ] = v
45- state_dict = new_state_dict
37+ state_dict = {(k [7 :] if k .startswith ('module.' ) else k ): v for k , v in state_dict .items ()}
4638
4739 for name , module in model .named_modules ():
4840 if hasattr (module , 'lora' ):
def save_lora(model, path):
    """Collect the LoRA adapter weights from `model` and save them to `path`.

    Only submodules that carry a `lora` attribute (the injected LoRA adapter,
    itself an nn.Module) contribute to the checkpoint; the base model weights
    are not saved.

    Args:
        model: a torch.nn.Module whose LoRA-augmented submodules expose a
            `.lora` sub-module. May be a DDP-wrapped model, in which case
            submodule names carry a leading 'module.' prefix.
        path: destination file path passed to `torch.save`.

    Side effects:
        Writes a flat state dict `{<name>.lora.<param>: tensor}` to `path`,
        where <name> is the submodule's qualified name with any DDP
        'module.' prefix removed — so the checkpoint loads identically into
        wrapped and unwrapped models.
    """
    state_dict = {}
    for name, module in model.named_modules():
        if hasattr(module, 'lora'):
            # Strip the DDP 'module.' prefix so saved keys match the
            # unwrapped model's naming and load_lora can find them.
            clean_name = name[7:] if name.startswith("module.") else name
            lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
            state_dict.update(lora_state)
    torch.save(state_dict, path)
0 commit comments