@@ -1550,6 +1550,11 @@ def load_lora_weights(
15501550 )
15511551
15521552 if has_param_with_expanded_shape :
1553+ if not hasattr (self ,"_unloading_reset_list" ):
1554+ self ._lora_unloading_reset_list = [adapter_name ]
1555+ else :
1556+ self ._lora_unloading_reset_list .append (adapter_name )
1557+
15531558 logger .info (
15541559 "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
15551560 "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
@@ -1893,7 +1898,25 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
18931898 transformer .load_state_dict (transformer ._transformer_norm_layers , strict = False )
18941899 transformer ._transformer_norm_layers = None
18951900
1896- if reset_to_overwritten_params and getattr (transformer , "_overwritten_params" , None ) is not None :
1901+ if reset_to_overwritten_params and transformer is not None :
1902+ self ._maybe_reset_transformer (transformer )
1903+ self ._lora_unloading_reset_list .clear ()
1904+
1905+ def delete_adapters (self , adapter_names : Union [List [str ], str ]):
1906+ super ().delete_adapters (adapter_names )
1907+
1908+ adapter_names = [adapter_names ] if isinstance (adapter_names , str ) else adapter_names
1909+ for adapter_name in adapter_names :
1910+ if adapter_name in self ._lora_unloading_reset_list :
1911+ self ._lora_unloading_reset_list .remove (adapter_name )
1912+ # If more than 1 LoRA adapters expanded the transformer, we don't need to resest the transformer.
1913+ if len (self ._lora_unloading_reset_list ) == 0 :
1914+ transformer = getattr (self , self .transformer_name ) if not hasattr (self , "transformer" ) else self .transformer
1915+ self ._maybe_reset_transformer (transformer )
1916+
1917+ @classmethod
1918+ def _maybe_reset_transformer (cls , transformer : torch .nn .Module ):
1919+ if getattr (transformer , "_overwritten_params" , None ) is not None :
18971920 overwritten_params = transformer ._overwritten_params
18981921 module_names = set ()
18991922
@@ -1903,7 +1926,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
19031926
19041927 for name , module in transformer .named_modules ():
19051928 if isinstance (module , torch .nn .Linear ) and name in module_names :
1906- module_weight = module .weight .data
19071929 module_bias = module .bias .data if module .bias is not None else None
19081930 bias = module_bias is not None
19091931
@@ -1917,7 +1939,6 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
19171939 in_features ,
19181940 out_features ,
19191941 bias = bias ,
1920- dtype = module_weight .dtype ,
19211942 )
19221943
19231944 tmp_state_dict = {"weight" : current_param_weight }
0 commit comments