@@ -70,11 +70,14 @@ def _forward(self, *args, **kwargs):
7070 setattr (module , f'_part_{ adapter_name } ' , new_module )
7171 new_module .requires_grad_ (True )
7272
73- def state_dict_callback (state_dict , adapter_name ):
73+ def state_dict_callback (state_dict , adapter_name , ** kwargs ):
7474 new_state_dict = {}
7575 for key , value in state_dict .items ():
7676 if f'_part_{ adapter_name } .' in key :
77- new_key = key .replace (f'_part_{ adapter_name } .' , '' ).replace ('base_layer.' , '' )
77+ if kwargs .get ('replace_key' , True ):
78+ new_key = key .replace (f'_part_{ adapter_name } .' , '' ).replace ('base_layer.' , '' )
79+ else :
80+ new_key = key
7881 new_state_dict [new_key ] = value
7982
8083 return new_state_dict
@@ -90,11 +93,14 @@ def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Di
9093 for param_name in state_dict :
9194 if param_name .startswith (name ):
9295 end = param_name [len (name ):]
93- if hasattr (module , 'base_layer' ):
94- new_state_dict [name + f'.base_layer._part_{ adapter_name } '
95- + end ] = state_dict [param_name ]
96+ if '_part_' not in param_name :
97+ if hasattr (module , 'base_layer' ):
98+ new_state_dict [name + f'.base_layer._part_{ adapter_name } '
99+ + end ] = state_dict [param_name ]
100+ else :
101+ new_state_dict [name + f'._part_{ adapter_name } ' + end ] = state_dict [param_name ]
96102 else :
97- new_state_dict [name + f'._part_ { adapter_name } ' + end ] = state_dict [param_name ]
103+ new_state_dict [param_name ] = state_dict [param_name ]
98104 return new_state_dict
99105
100106 return SwiftOutput (
0 commit comments