@@ -2892,6 +2892,20 @@ def _get_trainable_parameters(self):
             return self.lycoris_wrapped_network.parameters()
         return [param for param in self.model.get_trained_component(unwrap_model=False).parameters() if param.requires_grad]
 
+    def _get_slider_tuner_layers(self):
+        """Return cached list of (id, module) for BaseTunerLayer modules that have scaling dicts."""
+        if hasattr(self, "_slider_tuner_layers_cache"):
+            return self._slider_tuner_layers_cache
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        result = []
+        for name, module in self.model.get_trained_component().named_modules():
+            if isinstance(module, BaseTunerLayer) and hasattr(module, "scaling"):
+                result.append((id(module), module))
+        self._slider_tuner_layers_cache = result
+        return result
+
     def _ensure_parameter_dtype(self, parameters, target_dtype: torch.dtype, optimizer_name: str | None = None):
         converted = 0
         for param_or_group in parameters:
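
A note on what the cached (id, module) pairs feed: in PEFT, each BaseTunerLayer keeps a `scaling` dict mapping adapter name to its alpha/r multiplier, and the adapted forward is roughly base(x) + scaling[adapter] * lora_B(lora_A(x)), so multiplying every entry of `scaling` scales the whole LoRA delta. A minimal standalone sketch of that scale-and-save idea (the function name is hypothetical, not part of this diff):

    # Sketch: scale every adapter's LoRA contribution by `strength`,
    # returning the original scaling values so they can be restored later.
    from peft.tuners.tuners_utils import BaseTunerLayer

    def scale_adapters(model, strength):
        saved = {}
        for module in model.modules():
            if isinstance(module, BaseTunerLayer) and hasattr(module, "scaling"):
                # shallow-copy the originals before mutating in place
                saved[id(module)] = (module, dict(module.scaling))
                for key in module.scaling:
                    module.scaling[key] = module.scaling[key] * strength
        return saved
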
@@ -5431,6 +5445,28 @@ def train(self):
                 else:
                     self.model.get_trained_component().enable_lora()
 
+                # slider: scale the LoRA contribution by the per-batch strength
+                raw_strength = prepared_batch.get("slider_strength", 1.0)
+                try:
+                    strength = float(raw_strength)
+                except (TypeError, ValueError):
+                    strength = 1.0
+
+                slider_original_scaling = None
+                if self.config.model_type == "lora" and strength != 1.0:
+                    with torch.no_grad():
+                        if self.config.lora_type.lower() == "lycoris":
+                            self.accelerator._lycoris_wrapped_network.set_multiplier(strength)
+                        else:
+                            tuner_layers = self._get_slider_tuner_layers()
+                            slider_original_scaling = {}
+                            for layer_id, module in tuner_layers:
+                                saved = {}
+                                for key, val in module.scaling.items():
+                                    saved[key] = val
+                                    module.scaling[key] = val * strength
+                                slider_original_scaling[layer_id] = (module, saved)
+
                 training_logger.debug("Predicting.")
                 model_pred = self.model_predict(
                     prepared_batch=prepared_batch,
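
The scale is applied just before model_predict and the restore (next hunk) runs after the distiller step, so an exception in between would leave the adapters scaled. A hedged alternative sketch of the same scale-then-restore pattern as a context manager, which guarantees restoration via try/finally (function name hypothetical, assuming the same (id, module) pair list as above):

    import contextlib

    import torch

    @contextlib.contextmanager
    def slider_scaling(tuner_layers, strength):
        # Save originals, scale in place, and always restore on exit.
        saved = {}
        try:
            with torch.no_grad():
                for layer_id, module in tuner_layers:
                    saved[layer_id] = (module, dict(module.scaling))
                    for key in module.scaling:
                        module.scaling[key] *= strength
            yield
        finally:
            with torch.no_grad():
                for module, original in saved.values():
                    module.scaling.update(original)
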
@@ -5601,6 +5637,14 @@ def train(self):
                 ):
                     self.distiller.discriminator_step(prepared_batch=prepared_batch)
                 self.distiller.post_training_step(self.model, step)
+                if self.config.model_type == "lora" and strength != 1.0:
+                    with torch.no_grad():
+                        if self.config.lora_type.lower() == "lycoris":
+                            self.accelerator._lycoris_wrapped_network.set_multiplier(1.0)
+                        elif slider_original_scaling is not None:
+                            for module, saved in slider_original_scaling.values():
+                                for key, val in saved.items():
+                                    module.scaling[key] = val
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 wandb_logs = {}
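
One plausible reason the restore path writes back the saved originals instead of dividing the live values by `strength`: division is not guaranteed to round-trip in floating point, so drift could accumulate across steps. A tiny illustration (values arbitrary):

    # Multiplying then dividing by the same factor need not round-trip in floats,
    # so restoring saved originals (as the hunk above does) is the safe choice.
    original = 0.1
    strength = 3.0
    scaled = original * strength
    print(scaled / strength == original)  # False: 0.10000000000000002 != 0.1
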