@@ -1819,24 +1819,27 @@ def prune_state_dict_(state_dict):
             if "transformer." in k and any(norm_key in k for norm_key in supported_norm_keys)
         }
 
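+        # Resolve the transformer module once and, if needed, expand its Linear layer shapes in place to match the incoming LoRA/norm weights.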
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        self._maybe_expand_transformer_param_shape_(
+            transformer, transformer_lora_state_dict, transformer_norm_state_dict
+        )
+
         if len(transformer_lora_state_dict) > 0:
             self.load_lora_into_transformer(
                 transformer_lora_state_dict,
                 network_alphas=network_alphas,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
                 adapter_name=adapter_name,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
 
         if len(transformer_norm_state_dict) > 0:
-            self.load_norm_into_transformer(
+            self._transformer_norm_layers = self.load_norm_into_transformer(
                 transformer_norm_state_dict,
-                transformer=getattr(self, self.transformer_name)
-                if not hasattr(self, "transformer")
-                else self.transformer,
+                transformer=transformer,
+                discard_original_layers=False,
             )
 
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
@@ -1899,10 +1902,41 @@ def load_lora_into_transformer(
     def load_norm_into_transformer(
         cls,
         state_dict,
-        transformer: torch.nn.Module,
-    ):
-        print(state_dict.keys())
-        transformer.load_state_dict(state_dict, strict=True)
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
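+        # Load the fused normalization-layer weights from `state_dict` into `transformer` and return the original
+        # parameters that get overwritten so they can be restored later (e.g. when unloading or unfusing the LoRA).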
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+        logger.warning(
+            f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+        )
+
+        for key in extra_keys:
+            state_dict.pop(key)
+
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers[key] = transformer_state_dict[key]
+
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        logger.info(
+            "Normalization layers in LoRA state dict can only be loaded if fused directly in the transformer. Calls to `.fuse_lora()` will only affect the LoRA layers and not the normalization layers."
+        )
+        transformer.load_state_dict(state_dict, strict=False)
+
+        return overwritten_layers
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2139,6 +2173,11 @@ def fuse_lora(
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+        if hasattr(self, "_transformer_norm_layers") and len(self._transformer_norm_layers.keys()) > 0:
+            logger.info(
+                "Normalization layers were loaded directly into the transformer and cannot be fused. Calls to `.fuse_lora()` will only affect the actual LoRA layers."
+            )
+
         super().fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )
@@ -2157,8 +2196,83 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
         Args:
             components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
         """
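+        # Restore the original normalization-layer weights that were overwritten when the LoRA was loaded.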
+        if hasattr(self, "_transformer_norm_layers") and len(self._transformer_norm_layers) > 0:
+            transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+            transformer.load_state_dict(self._transformer_norm_layers, strict=False)
+
         super().unfuse_lora(components=components)
 
+    @classmethod
+    def _maybe_expand_transformer_param_shape_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ):
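+        # Collect all incoming weights and, for every `nn.Linear` whose shape is smaller than what the
+        # checkpoint expects, swap it for a zero-padded, larger Linear layer in place.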
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+
+        def get_submodule(module, name):
+            for part in name.split("."):
+                if len(name) == 0:
+                    break
+                if not hasattr(module, part):
+                    raise AttributeError(f"Submodule '{part}' not found in {module}.")
+                module = getattr(module, part)
+            return module
+
+        # Expand transformer parameter shapes if they don't match lora
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+                name_split = name.split(".")
+
+                lora_A_name = f"{name}.lora_A"
+                lora_B_name = f"{name}.lora_B"
+                lora_A_weight_name = f"{lora_A_name}.weight"
+                lora_B_weight_name = f"{lora_B_name}.weight"
+
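+                # The LoRA A/B matrices determine the (in_features, out_features) the base Linear layer must have.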
+                if lora_A_weight_name not in state_dict.keys():
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
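+                # Nothing to do if the module already has the shape the LoRA weights expect.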
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                parent_module_name = ".".join(name_split[:-1])
+                current_module_name = name_split[-1]
+                parent_module = get_submodule(transformer, parent_module_name)
+
+                expanded_module = torch.nn.Linear(
+                    in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
+                )
+
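+                # Copy the original parameters into the top-left block of the larger module; the newly added rows/columns stay zero.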
+                new_weight = module_weight.new_zeros(expanded_module.weight.data.shape)
+                slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                new_weight[slices] = module_weight
+                expanded_module.weight.data.copy_(new_weight)
+
+                if bias:
+                    new_bias = module_bias.new_zeros(expanded_module.bias.data.shape)
+                    slices = tuple(slice(0, dim) for dim in module_bias.shape)
+                    new_bias[slices] = module_bias
+                    expanded_module.bias.data.copy_(new_bias)
+
+                setattr(parent_module, current_module_name, expanded_module)
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.