@@ -193,6 +193,8 @@ def load_lora_adapter(
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
 
+        from .lora_base import LORA_ADAPTER_METADATA_KEY
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -236,11 +238,11 @@ def load_lora_adapter(
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            metadata = state_dict.pop("lora_adapter_metadata", None)
+            metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None)
             state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
             if metadata is not None:
-                state_dict["lora_adapter_metadata"] = metadata
+                state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
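For context on the ordering in this hunk: the metadata entry sits under a bare key with no `{prefix}.` namespace, so it has to be popped before the prefix filter runs and re-attached afterwards, otherwise it would be silently dropped. A minimal sketch of that behavior, assuming `LORA_ADAPTER_METADATA_KEY` resolves to the string `"lora_adapter_metadata"` (the prefix and tensors below are illustrative, not taken from this diff):

```python
import torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed value of the constant

state_dict = {
    "unet.to_q.lora_A.weight": torch.zeros(4, 16),
    "text_encoder.to_q.lora_A.weight": torch.zeros(4, 16),
    LORA_ADAPTER_METADATA_KEY: {"r": 4, "lora_alpha": 4},
}
prefix = "unet"

# Pop the metadata first: it carries no "unet." prefix, so the filter below
# would otherwise drop it together with the text_encoder keys.
metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None)
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

# Re-attach it so downstream consumers can still recover the adapter config.
if metadata is not None:
    state_dict[LORA_ADAPTER_METADATA_KEY] = metadata

assert set(state_dict) == {"to_q.lora_A.weight", LORA_ADAPTER_METADATA_KEY}
```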
@@ -464,23 +466,19 @@ def save_lora_adapter(
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
-            lora_adapter_metadata: TODO
         """
         from peft.utils import get_peft_model_state_dict
 
-        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
-        if lora_adapter_metadata is not None and not safe_serialization:
-            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
-        if not isinstance(lora_adapter_metadata, dict):
-            raise ValueError("`lora_adapter_metadata` must be of type `dict`.")
+        from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
 
         if adapter_name is None:
             adapter_name = get_adapter_name(self)
 
         if adapter_name not in getattr(self, "peft_config", {}):
             raise ValueError(f"Adapter name {adapter_name} not found in the model.")
 
+        lora_adapter_metadata = self.peft_config[adapter_name]
+
         lora_layers_to_save = get_peft_model_state_dict(
             self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
         )
@@ -497,7 +495,7 @@ def save_function(weights, filename):
                 for key, value in lora_adapter_metadata.items():
                     if isinstance(value, set):
                         lora_adapter_metadata[key] = list(value)
-                metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+                metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
 
                 return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
@@ -512,7 +510,6 @@ def save_function(weights, filename):
             else:
                 weight_name = LORA_WEIGHT_NAME
 
-        # TODO: we could consider saving the `peft_config` as well.
         save_path = Path(save_directory, weight_name).as_posix()
         save_function(lora_layers_to_save, save_path)
         logger.info(f"Model weights saved in {save_path}")
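Net effect of the save-side changes: the adapter's `peft_config` is JSON-encoded into the safetensors header under the metadata key instead of being left out. A rough sketch of that round trip, assuming `LORA_ADAPTER_METADATA_KEY` equals `"lora_adapter_metadata"`; the weights, config values, and file name are placeholders rather than anything from this diff:

```python
import json

import safetensors.torch
import torch
from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed value of the constant

# Save: JSON-encode the adapter config into the safetensors header metadata,
# next to the LoRA weights (header metadata values must be strings).
weights = {"to_q.lora_A.weight": torch.zeros(4, 16)}
adapter_config = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_v"]}
metadata = {"format": "pt"}
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(adapter_config, indent=2, sort_keys=True)
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=metadata)

# Load: the header comes back as a dict of strings, so the config is
# recovered with json.loads and could be used to rebuild a LoraConfig.
with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    stored = f.metadata() or {}
recovered_config = json.loads(stored[LORA_ADAPTER_METADATA_KEY])
```

Keeping the config in the same file is presumably what lets the load path pop it back out of the state dict (first hunk) without introducing a second artifact alongside the weights.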