@@ -49,13 +49,21 @@ def __init__(self,
4949 model = model .base_model
5050
5151 if isinstance (config , SwiftConfig ):
52- self .adapters [DEFAULT_ADAPTER ] = self ._prepare_model (
53- model , config , DEFAULT_ADAPTER )
52+ if DEFAULT_ADAPTER not in self .adapters :
53+ self .adapters [DEFAULT_ADAPTER ] = self ._prepare_model (
54+ model , config , DEFAULT_ADAPTER )
55+ else :
56+ logger .warn (
57+ f'Adapter { DEFAULT_ADAPTER } has been patched, skip.' )
5458 elif isinstance (config , dict ):
5559 assert (all (isinstance (c , SwiftConfig ) for c in config .values ()))
5660 for adapter_name , _config in config .items ():
57- self .adapters [adapter_name ] = self ._prepare_model (
58- model , _config , adapter_name )
61+ if adapter_name not in self .adapters :
62+ self .adapters [adapter_name ] = self ._prepare_model (
63+ model , _config , adapter_name )
64+ else :
65+ logger .warn (
66+ f'Adapter { adapter_name } has been patched, skip.' )
5967 self .model = model
6068
6169 self .extra_state_keys = extra_state_keys or []
@@ -195,7 +203,8 @@ def load_state_file(path):
195203 def from_pretrained (cls ,
196204 model : Union [nn .Module , 'SwiftModel' ],
197205 model_id : str = None ,
198- adapter_name : Union [str , List [str ]] = None ,
206+ adapter_name : Union [str , List [str ], Dict [str ,
207+ str ]] = None ,
199208 inference_mode : bool = False ,
200209 revision : str = None ,
201210 ** kwargs ):
@@ -205,7 +214,7 @@ def from_pretrained(cls,
205214 model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned,
206215 if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped..
207216 model_id (`str`): The model_id or a local model dir of tuners to use to tune the model.
208- adapter_name (`Union[str, List[str]]`): The adapter_names saved in the model repo to load.
217+ adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load.
209218 Default `None`, means load all tuners saved in the model_id
210219 inference_mode (`bool`): Use in the inference mode or not.
211220 revision (`str`): The model revision to use.
@@ -236,7 +245,8 @@ def from_pretrained(cls,
236245 os .path .isfile (os .path .join (model_dir , sub_dir , CONFIG_NAME ))
237246 ]
238247 for _name in adapter_name if isinstance (adapter_name ,
239- list ) else [adapter_name ]:
248+ list ) else [adapter_name ] \
249+ if isinstance (adapter_name , str ) else adapter_name .keys ():
240250 sub_folder = os .path .join (model_dir , _name )
241251 config_file = os .path .join (sub_folder , CONFIG_NAME )
242252
@@ -250,26 +260,31 @@ def from_pretrained(cls,
250260 if SWIFT_TYPE_KEY not in json_object :
251261 raise ValueError ('Mixed using with peft is not allowed now.' )
252262 else :
253- adapters [_name ] = SwiftConfig .from_pretrained (sub_folder )
263+ key = _name if not isinstance (adapter_name ,
264+ dict ) else adapter_name [_name ]
265+ adapters [key ] = SwiftConfig .from_pretrained (sub_folder )
254266
255267 self = SwiftModel (model , adapters , extra_state_keys , inference_mode ,
256268 ** kwargs )
257269 for _name in adapter_name if isinstance (adapter_name ,
258- list ) else [adapter_name ]:
270+ list ) else [adapter_name ] \
271+ if isinstance (adapter_name , str ) else adapter_name .keys ():
259272 sub_folder = os .path .join (model_dir , _name )
260273 state_dict = cls .load_state_file (sub_folder )
274+ _adapter = _name if not isinstance (adapter_name ,
275+ dict ) else adapter_name [_name ]
261276 if state_dict is not None :
262277 model_is_qlora = len ([
263278 k for k in self .state_dict ().keys ()
264- if k .endswith ('.lora_A.default.weight' )
265- or k .endswith ('.lora_B.default.weight' )
279+ if k .endswith (f'.lora_A.{_adapter}.weight' )
280+ or k .endswith (f'.lora_B.{_adapter}.weight' )
266281 ])
267282 if not model_is_qlora :
268283 # model is lora, state_dict: qlora->lora
269284 state_dict = {
270- k [:- len ('.default.weight' ) if k .
271- endswith ('.lora_A.default.weight' ) or k .
272- endswith ('.lora_B.default.weight' ) else None ]: v
285+ k [:- len (f'.{_name}.weight' ) if k .
286+ endswith (f'.lora_A.{_name}.weight' ) or k .
287+ endswith (f'.lora_B.{_name}.weight' ) else None ]: v
273288 for k , v in state_dict .items ()
274289 }
275290 if any (['loramodule' in key for key in state_dict ]):
@@ -288,7 +303,13 @@ def from_pretrained(cls,
288303 f'lora_B.{ _name } .weight' ): value
289304 for key , value in state_dict .items ()
290305 }
291- self .load_state_dict (state_dict , adapter_name = _name )
306+ if isinstance (adapter_name , dict ):
307+ # TODO this logic is fragile! replace `_name` may cause other parts replaced
308+ state_dict = {
309+ key .replace (_name , adapter_name [_name ]): value
310+ for key , value in state_dict .items ()
311+ }
312+ self .load_state_dict (state_dict , adapter_name = _adapter )
292313 state_dict = cls .load_state_file (model_dir )
293314 if state_dict is not None :
294315 self .load_state_dict (state_dict )
@@ -569,7 +590,8 @@ def unmerge(model: Union[PeftModel, SwiftModel], **kwargs):
569590 @staticmethod
570591 def from_pretrained (model : Union [nn .Module , SwiftModel ],
571592 model_id : str = None ,
572- adapter_name : Union [str , List [str ]] = None ,
593+ adapter_name : Union [str , List [str ], Dict [str ,
594+ str ]] = None ,
573595 revision : str = None ,
574596 ** kwargs ):
575597 """Prepare a model by a model_id in the ModelScope hub or a local dir.
@@ -593,7 +615,8 @@ def from_pretrained(model: Union[nn.Module, SwiftModel],
593615 is_peft_model = SWIFT_TYPE_KEY not in _json
594616
595617 _name = adapter_name if isinstance (
596- adapter_name , str ) or adapter_name is None else adapter_name [0 ]
618+ adapter_name , str ) or adapter_name is None else adapter_name [0 ] \
619+ if isinstance (adapter_name , list ) else list (adapter_name .keys ())[0 ]
597620 _name = _name or ''
598621 if os .path .exists (os .path .join (model_id , _name , CONFIG_NAME )):
599622 with open (os .path .join (model_id , _name , CONFIG_NAME ), 'r' ) as f :
0 commit comments