@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
+        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
+        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
         def map_to(module, state_dict, *args, **kwargs):
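For orientation, here is a minimal sketch (with made-up processor names, not the keys diffusers actually produces) of what `mapping`, `rev_mapping`, and the newly added `split_keys` end up holding:

```python
# Illustrative only: the key names below are placeholders.
keys = ["blockA.attn1.processor", "text_model.encoder.layers.0.self_attn.q_proj"]

mapping = dict(enumerate(keys))
# -> {0: "blockA.attn1.processor", 1: "text_model.encoder.layers.0.self_attn.q_proj"}

rev_mapping = {v: k for k, v in enumerate(keys)}
# -> {"blockA.attn1.processor": 0, "text_model.encoder.layers.0.self_attn.q_proj": 1}

split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
```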
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):
 
             return new_state_dict
 
+        def remap_key(key, state_dict):
+            for k in self.split_keys:
+                if k in key:
+                    return key.split(k)[0] + k
+
+            raise ValueError(
+                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+            )
+
         def map_from(module, state_dict, *args, **kwargs):
             all_keys = list(state_dict.keys())
             for key in all_keys:
-                replace_key = key.split(".processor")[0] + ".processor"
+                replace_key = remap_key(key, state_dict)
                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                 state_dict[new_key] = state_dict[key]
                 del state_dict[key]
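A standalone sketch of the string logic in the new `remap_key` helper, run on a made-up key (the key layout is illustrative, not taken from the PR): it returns the prefix up to and including the first matching split token, which `map_from` then swaps for the flat `layers.{index}` prefix.

```python
split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key):
    # return everything up to and including the first split token found in the key
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to have one of {split_keys}.")

# hypothetical text-encoder style key
print(remap_key("text_model.encoder.layers.0.self_attn.q_proj.down.weight"))
# -> "text_model.encoder.layers.0.self_attn.q_proj"
```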
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
                 self._modify_text_encoder(attn_procs_text_encoder)
 
+                # save lora attn procs of text encoder so that it can be easily retrieved
+                self._text_encoder_lora_attn_procs = attn_procs_text_encoder
+
         # Otherwise, we're dealing with the old format. This means the `state_dict` should only
         # contain the module names of the `unet` as its keys WITHOUT any prefix.
         elif not all(
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
+    @property
+    def text_encoder_lora_attn_procs(self):
+        if hasattr(self, "_text_encoder_lora_attn_procs"):
+            return self._text_encoder_lora_attn_procs
+        return
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
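A hypothetical usage sketch of the new property (model id and LoRA path are placeholders): after `load_lora_weights` has loaded text-encoder LoRA layers, the attention processors can be retrieved again; the property returns `None` if none were stored.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora")  # placeholder path

# dict of text-encoder LoRA attention processors, or None if none were loaded
text_encoder_procs = pipe.text_encoder_lora_attn_procs
```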
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
     def save_lora_weights(
         self,
         save_directory: Union[str, os.PathLike],
-        unet_lora_layers: Dict[str, torch.nn.Module] = None,
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
         text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
         is_main_process: bool = True,
         weight_name: str = None,
@@ -1123,13 +1144,14 @@ def save_lora_weights(
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            unet_lora_layers (`Dict[str, torch.nn.Module`]):
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
-                serialization process easier and cleaner.
-            text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
+                serialization process easier and cleaner. Values can either be LoRA `torch.nn.Module` layers or plain
+                torch weight tensors.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
                 `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
-                dict.
+                dict. Values can either be LoRA `torch.nn.Module` layers or plain torch weight tensors.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful when in distributed training like
                 TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
         # Create a flat dictionary.
         state_dict = {}
         if unet_lora_layers is not None:
-            unet_lora_state_dict = {
-                f"{self.unet_name}.{module_name}": param
-                for module_name, param in unet_lora_layers.state_dict().items()
-            }
+            weights = (
+                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+            )
+
+            unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
             state_dict.update(unet_lora_state_dict)
+
         if text_encoder_lora_layers is not None:
+            weights = (
+                text_encoder_lora_layers.state_dict()
+                if isinstance(text_encoder_lora_layers, torch.nn.Module)
+                else text_encoder_lora_layers
+            )
+
             text_encoder_lora_state_dict = {
-                f"{self.text_encoder_name}.{module_name}": param
-                for module_name, param in text_encoder_lora_layers.state_dict().items()
+                f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
             }
             state_dict.update(text_encoder_lora_state_dict)
 
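A hedged usage sketch of the relaxed `save_lora_weights` signature (the tensor dict below is a made-up placeholder): both arguments now accept either an `nn.Module` holding the LoRA layers or a plain `{str: torch.Tensor}` state dict, since the branches above only call `.state_dict()` when given a module.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Placeholder text-encoder LoRA weights: a plain tensor dict is now accepted directly.
dummy_text_encoder_lora = {
    "text_model.encoder.layers.0.self_attn.q_proj.down.weight": torch.zeros(4, 768)
}

pipe.save_lora_weights(
    save_directory="./lora-out",
    unet_lora_layers=None,                              # may also be an nn.Module such as AttnProcsLayers
    text_encoder_lora_layers=dummy_text_encoder_lora,   # Dict[str, torch.Tensor]
)
```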