1414
1515import  copy 
1616import  inspect 
17+ import  json 
1718import  os 
1819from  pathlib  import  Path 
1920from  typing  import  Callable , Dict , List , Optional , Union 
4546    set_adapter_layers ,
4647    set_weights_and_activate_adapters ,
4748)
49+ from  ..utils .state_dict_utils  import  _load_sft_state_dict_metadata 
4850
4951
5052if  is_transformers_available ():
6264
6365LORA_WEIGHT_NAME  =  "pytorch_lora_weights.bin" 
6466LORA_WEIGHT_NAME_SAFE  =  "pytorch_lora_weights.safetensors" 
67+ LORA_ADAPTER_METADATA_KEY  =  "lora_adapter_metadata" 
6568
6669
6770def  fuse_text_encoder_lora (text_encoder , lora_scale = 1.0 , safe_fusing = False , adapter_names = None ):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
206209    subfolder ,
207210    user_agent ,
208211    allow_pickle ,
212+     metadata = None ,
209213):
210214    model_file  =  None 
211215    if  not  isinstance (pretrained_model_name_or_path_or_dict , dict ):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
236240                    user_agent = user_agent ,
237241                )
238242                state_dict  =  safetensors .torch .load_file (model_file , device = "cpu" )
243+                 metadata  =  _load_sft_state_dict_metadata (model_file )
244+ 
239245            except  (IOError , safetensors .SafetensorError ) as  e :
240246                if  not  allow_pickle :
241247                    raise  e 
242248                # try loading non-safetensors weights 
243249                model_file  =  None 
250+                 metadata  =  None 
244251                pass 
245252
246253        if  model_file  is  None :
@@ -261,10 +268,11 @@ def _fetch_state_dict(
261268                user_agent = user_agent ,
262269            )
263270            state_dict  =  load_state_dict (model_file )
271+             metadata  =  None 
264272    else :
265273        state_dict  =  pretrained_model_name_or_path_or_dict 
266274
267-     return  state_dict 
275+     return  state_dict ,  metadata 
268276
269277
270278def  _best_guess_weight_name (
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
306314    return  weight_name 
307315
308316
317+ def  _pack_dict_with_prefix (state_dict , prefix ):
318+     sd_with_prefix  =  {f"{ prefix }  .{ key }  " : value  for  key , value  in  state_dict .items ()}
319+     return  sd_with_prefix 
320+ 
321+ 
309322def  _load_lora_into_text_encoder (
310323    state_dict ,
311324    network_alphas ,
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
317330    _pipeline = None ,
318331    low_cpu_mem_usage = False ,
319332    hotswap : bool  =  False ,
333+     metadata = None ,
320334):
321335    if  not  USE_PEFT_BACKEND :
322336        raise  ValueError ("PEFT backend is required for this method." )
323337
338+     if  network_alphas  and  metadata :
339+         raise  ValueError ("`network_alphas` and `metadata` cannot be specified both at the same time." )
340+ 
324341    peft_kwargs  =  {}
325342    if  low_cpu_mem_usage :
326343        if  not  is_peft_version (">=" , "0.13.1" ):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
349366    # Load the layers corresponding to text encoder and make necessary adjustments. 
350367    if  prefix  is  not   None :
351368        state_dict  =  {k .removeprefix (f"{ prefix }  ." ): v  for  k , v  in  state_dict .items () if  k .startswith (f"{ prefix }  ." )}
369+         if  metadata  is  not   None :
370+             metadata  =  {k .removeprefix (f"{ prefix }  ." ): v  for  k , v  in  metadata .items () if  k .startswith (f"{ prefix }  ." )}
352371
353372    if  len (state_dict ) >  0 :
354373        logger .info (f"Loading { prefix }  ." )
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
376395            alpha_keys  =  [k  for  k  in  network_alphas .keys () if  k .startswith (prefix ) and  k .split ("." )[0 ] ==  prefix ]
377396            network_alphas  =  {k .removeprefix (f"{ prefix }  ." ): v  for  k , v  in  network_alphas .items () if  k  in  alpha_keys }
378397
379-         lora_config_kwargs  =  get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
398+         if  metadata  is  not   None :
399+             lora_config_kwargs  =  metadata 
400+         else :
401+             lora_config_kwargs  =  get_peft_kwargs (rank , network_alphas , state_dict , is_unet = False )
380402
381403        if  "use_dora"  in  lora_config_kwargs :
382404            if  lora_config_kwargs ["use_dora" ]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
398420                if  is_peft_version ("<=" , "0.13.2" ):
399421                    lora_config_kwargs .pop ("lora_bias" )
400422
401-         lora_config  =  LoraConfig (** lora_config_kwargs )
423+             try :
424+                 lora_config  =  LoraConfig (** lora_config_kwargs )
425+             except  TypeError  as  e :
426+                 raise  TypeError ("`LoraConfig` class could not be instantiated." ) from  e 
402427
403428        # adapter_name 
404429        if  adapter_name  is  None :
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
889914    @staticmethod  
890915    def  pack_weights (layers , prefix ):
891916        layers_weights  =  layers .state_dict () if  isinstance (layers , torch .nn .Module ) else  layers 
892-         layers_state_dict  =  {f"{ prefix }  .{ module_name }  " : param  for  module_name , param  in  layers_weights .items ()}
893-         return  layers_state_dict 
917+         return  _pack_dict_with_prefix (layers_weights , prefix )
894918
895919    @staticmethod  
896920    def  write_lora_layers (
@@ -900,16 +924,32 @@ def write_lora_layers(
900924        weight_name : str ,
901925        save_function : Callable ,
902926        safe_serialization : bool ,
927+         lora_adapter_metadata : Optional [dict ] =  None ,
903928    ):
904929        if  os .path .isfile (save_directory ):
905930            logger .error (f"Provided path ({ save_directory }  ) should be a directory, not a file" )
906931            return 
907932
933+         if  lora_adapter_metadata  and  not  safe_serialization :
934+             raise  ValueError ("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`." )
935+         if  lora_adapter_metadata  and  not  isinstance (lora_adapter_metadata , dict ):
936+             raise  TypeError ("`lora_adapter_metadata` must be of type `dict`." )
937+ 
908938        if  save_function  is  None :
909939            if  safe_serialization :
910940
911941                def  save_function (weights , filename ):
912-                     return  safetensors .torch .save_file (weights , filename , metadata = {"format" : "pt" })
942+                     # Inject framework format. 
943+                     metadata  =  {"format" : "pt" }
944+                     if  lora_adapter_metadata :
945+                         for  key , value  in  lora_adapter_metadata .items ():
946+                             if  isinstance (value , set ):
947+                                 lora_adapter_metadata [key ] =  list (value )
948+                         metadata [LORA_ADAPTER_METADATA_KEY ] =  json .dumps (
949+                             lora_adapter_metadata , indent = 2 , sort_keys = True 
950+                         )
951+ 
952+                     return  safetensors .torch .save_file (weights , filename , metadata = metadata )
913953
914954            else :
915955                save_function  =  torch .save 
0 commit comments