@@ -1819,24 +1819,27 @@ def prune_state_dict_(state_dict):
             if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
         }
 
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        self._maybe_expand_transformer_param_shape_(
+            transformer, transformer_lora_state_dict, transformer_norm_state_dict
+        )
+        print(transformer)
+
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
                 transformer_lora_state_dict,
                 network_alphas=network_alphas,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
         if len(transformer_norm_state_dict) > 0:
-            self.load_norm_into_transformer(
+            self._transformer_norm_layers = self.load_norm_into_transformer(
                 transformer_norm_state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
+                discard_original_layers=False,
             )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
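
The hunk above resolves the transformer once, expands any parameter shapes that are narrower than the incoming LoRA expects, and then loads the LoRA and normalization weights separately, keeping the overwritten norm layers in `self._transformer_norm_layers`. A minimal usage sketch of that path follows; the LoRA repo id is a placeholder, not one referenced by this commit:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Hypothetical LoRA checkpoint that also ships fused normalization weights
# (keys such as "transformer....norm.linear.weight" next to the usual lora_A/lora_B keys).
pipe.load_lora_weights("your-org/flux-lora-with-norms", adapter_name="control")

# fuse_lora()/unfuse_lora() act on the LoRA layers only; the norm weights were written
# directly into the transformer, with the originals stashed for later restoration.
pipe.fuse_lora(lora_scale=1.0)
pipe.unfuse_lora()
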
@@ -1899,10 +1902,41 @@ def load_lora_into_transformer(
     def load_norm_into_transformer(
         cls,
         state_dict,
-        transformer: torch.nn.Module,
-    ):
-        print(state_dict.keys())
-        transformer.load_state_dict(state_dict, strict=True)
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+        logger.warning(
+            f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+        )
+
+        for key in extra_keys:
+            state_dict.pop(key)
+
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers[key] = transformer_state_dict[key]
+
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        logger.info(
+            "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers."
+        )
+        transformer.load_state_dict(state_dict, strict=False)
+
+        return overwritten_layers
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
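
`load_norm_into_transformer` strips the `transformer.` prefix, drops keys the model does not have, stashes the parameters it is about to overwrite, and then loads with `strict=False` because the state dict only covers the norm layers. A self-contained sketch of that stash-then-partial-load pattern, using a toy module rather than the Flux transformer:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(4)
        self.proj = torch.nn.Linear(4, 4)

model = Toy()

# Partial state dict containing only normalization weights (prefix already stripped).
norm_sd = {"norm.weight": torch.full((4,), 2.0), "norm.bias": torch.zeros(4)}

# Keep copies of the parameters that are about to be overwritten, then load with
# strict=False since norm_sd does not cover every parameter of the module.
overwritten = {k: model.state_dict()[k].clone() for k in norm_sd}
model.load_state_dict(norm_sd, strict=False)

# Later (mirroring the unfuse_lora change further down in this diff),
# restoring the originals is another partial load.
model.load_state_dict(overwritten, strict=False)
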
@@ -2139,6 +2173,11 @@ def fuse_lora(
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+        if len(self._transformer_norm_layers.keys()) > 0:
+            logger.info(
+                "Normalization layers cannot be loaded without fusing. Calls to `.fuse_lora()` will only affect the actual LoRA layers."
+            )
+
         super().fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
@@ -2157,8 +2196,83 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
         Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
        """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        transformer.load_state_dict(self._transformer_norm_layers)
+
         super().unfuse_lora(components=components)
 
+    @classmethod
+    def _maybe_expand_transformer_param_shape_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ):
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        def get_submodule(module, name):
+            for part in name.split("."):
+                if len(name) == 0:
+                    break
+                if not hasattr(module, part):
+                    raise AttributeError(f"Submodule '{part}' not found in '{module}'.")
+                module = getattr(module, part)
+            return module
+
+        # Expand transformer parameter shapes if they don't match lora
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+                name_split = name.split(".")
+
+                lora_A_name = f"{name}.lora_A"
+                lora_B_name = f"{name}.lora_B"
+                lora_A_weight_name = f"{lora_A_name}.weight"
+                lora_B_weight_name = f"{lora_B_name}.weight"
+
+                if lora_A_weight_name not in state_dict.keys():
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                parent_module_name = ".".join(name_split[:-1])
+                current_module_name = name_split[-1]
+                parent_module = get_submodule(transformer, parent_module_name)
+
+                expanded_module = torch.nn.Linear(
+                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                )
+
+                new_weight = module_weight.new_zeros(expanded_module.weight.data.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                new_weight[slices] = module_weight
+                expanded_module.weight.data.copy_(new_weight)
+
+                if bias:
+                    new_bias = module_bias.new_zeros(expanded_module.bias.data.shape)
+                    slices = tuple(slice(0, dim) for dim in module_bias.shape)
+                    new_bias[slices] = module_bias
+                    expanded_module.bias.data.copy_(new_bias)
+
+                setattr(parent_module, current_module_name, expanded_module)
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
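
`_maybe_expand_transformer_param_shape_` swaps a `torch.nn.Linear` for a larger one and copies the old weight and bias into the leading slice, leaving the padded entries zero, so a LoRA trained against a wider layer can still be applied. A standalone sketch of that zero-padded expansion, with a hypothetical `expand_linear` helper and illustrative sizes:

import torch

def expand_linear(module: torch.nn.Linear, in_features: int, out_features: int) -> torch.nn.Linear:
    # Build a larger Linear and copy the old weight/bias into the top-left slice;
    # the padded rows/columns stay zero so existing inputs behave exactly as before.
    bias = module.bias is not None
    expanded = torch.nn.Linear(
        in_features, out_features, bias=bias, device=module.weight.device, dtype=module.weight.dtype
    )
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[: module.out_features, : module.in_features].copy_(module.weight)
        if bias:
            expanded.bias.zero_()
            expanded.bias[: module.out_features].copy_(module.bias)
    return expanded

old = torch.nn.Linear(64, 128)
new = expand_linear(old, in_features=128, out_features=128)

x = torch.randn(2, 128)
# The extra input columns multiply zero weights, so the first 64 features match the old layer.
assert torch.allclose(new(x), old(x[:, :64]), atol=1e-6)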