@@ -48,10 +48,12 @@ def __init__(self,
4848 extra_state_keys .extend (model .extra_state_keys )
4949 model = model .base_model
5050
51+ new_adapters = []
5152 if isinstance (config , SwiftConfig ):
5253 if DEFAULT_ADAPTER not in self .adapters :
5354 self .adapters [DEFAULT_ADAPTER ] = self ._prepare_model (
5455 model , config , DEFAULT_ADAPTER )
56+ new_adapters .append (DEFAULT_ADAPTER )
5557 else :
5658 logger .warn (
5759 f'Adater { DEFAULT_ADAPTER } has been patched, skip.' )
@@ -61,6 +63,7 @@ def __init__(self,
6163 if adapter_name not in self .adapters :
6264 self .adapters [adapter_name ] = self ._prepare_model (
6365 model , _config , adapter_name )
66+ new_adapters .append (adapter_name )
6467 else :
6568 logger .warn (
6669 f'Adater { adapter_name } has been patched, skip.' )
@@ -76,14 +79,15 @@ def forward(self, *args, **kwargs):
7679 signature (self .base_model .forward ).parameters .values ())
7780 forward .__signature__ = Signature (_parameters )
7881 self .forward = MethodType (forward , self )
79- for adapter_name in self . adapters :
82+ for adapter_name in new_adapters :
8083 self .activate_adapter (adapter_name )
8184
8285 if inference_mode :
8386 self .eval ()
8487 else :
85- for output in self .adapters .values ():
86- output .mark_trainable_callback (model )
88+ for key , output in self .adapters .items ():
89+ if key in new_adapters :
90+ output .mark_trainable_callback (model )
8791 if self .extra_state_keys :
8892 for n , p in model .named_parameters ():
8993 if any (
0 commit comments