
Commit fc34e5f

Merge pull request bghira#2592 from bghira/sliders

Sliders

2 parents 46b77f9 + 7dad823

File tree: 4 files changed, +58 −0 lines changed


simpletuner/helpers/data_backend/factory.py

Lines changed: 8 additions & 0 deletions
@@ -3198,11 +3198,19 @@ def _create_dataset_and_sampler(
         elif caption_strategy == "instanceprompt":
             use_captions = False
 
+        slider_strength_raw = backend.get("slider_strength", 1.0)
+        try:
+            slider_strength = float(slider_strength_raw) if slider_strength_raw is not None else 1.0
+        except (TypeError, ValueError):
+            logging.warning("Invalid slider_strength %r in backend; defaulting to 1.0", slider_strength_raw)
+            slider_strength = 1.0
+
         init_backend["train_dataset"] = MultiAspectDataset(
             id=init_backend["id"],
             datasets=[init_backend["metadata_backend"]],
             is_regularisation_data=is_regularisation_data,
             is_i2v_data=is_i2v_data,
+            slider_strength=slider_strength,
         )
 
         if "deepfloyd" in self.args.model_type:

simpletuner/helpers/multiaspect/dataset.py

Lines changed: 3 additions & 0 deletions
@@ -31,12 +31,14 @@ def __init__(
         print_names: bool = False,
         is_regularisation_data: bool = False,
         is_i2v_data: bool = False,
+        slider_strength: float = 1.0,
     ):
         self.id = id
         self.datasets = datasets
         self.print_names = print_names
         self.is_regularisation_data = is_regularisation_data
         self.is_i2v_data = is_i2v_data
+        self.slider_strength = slider_strength
 
     def __len__(self):
         # Sum the length of all data backends:
@@ -54,6 +56,7 @@ def __getitem__(self, image_tuple: list[dict[str, Any] | TrainingSample]):
             "conditioning_samples": [],
             "is_regularisation_data": self.is_regularisation_data,
             "is_i2v_data": self.is_i2v_data,
+            "slider_strength": self.slider_strength,
         }
         first_aspect_ratio = None
         for sample in image_tuple:
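
With the new constructor argument, every example dict the dataset yields now carries its dataset-level strength, so the collate function can forward it per batch. A minimal sketch of the flow, where metadata_backend and image_tuple stand in for objects supplied by the surrounding training setup:

# Sketch: the per-dataset scalar rides along in every __getitem__ result.
dataset = MultiAspectDataset(
    id="slider-negative",
    datasets=[metadata_backend],  # assumed to come from the factory
    is_regularisation_data=False,
    is_i2v_data=False,
    slider_strength=-1.0,  # e.g. a negative-direction slider dataset
)
example = dataset[image_tuple]  # image_tuple is supplied by the sampler
assert example["slider_strength"] == -1.0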

simpletuner/helpers/training/collate.py

Lines changed: 3 additions & 0 deletions
@@ -1219,4 +1219,7 @@ def _conditioning_pixel_value_for_example(example_idx: int):
         "is_audio_only": is_audio_only,
         "s2v_audio_paths": s2v_audio_paths if any(s2v_audio_paths) else None,
         "s2v_audio_backend_ids": s2v_audio_backend_ids if any(s2v_audio_backend_ids) else None,
+
+
+        "slider_strength": batch.get("slider_strength")
     }
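
Note that the collate side calls batch.get("slider_strength") with no default, so batches from datasets that lack the new kwarg emit None; the trainer's float()/except path shown further down converts that into the neutral 1.0. A quick self-contained sketch of that degradation:

# How a missing key degrades gracefully to the neutral strength.
batch = {}  # no "slider_strength" key present
raw_strength = batch.get("slider_strength")  # -> None
try:
    strength = float(raw_strength)  # float(None) raises TypeError
except (TypeError, ValueError):
    strength = 1.0
assert strength == 1.0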

simpletuner/helpers/training/trainer.py

Lines changed: 44 additions & 0 deletions
@@ -2892,6 +2892,20 @@ def _get_trainable_parameters(self):
             return self.lycoris_wrapped_network.parameters()
         return [param for param in self.model.get_trained_component(unwrap_model=False).parameters() if param.requires_grad]
 
+    def _get_slider_tuner_layers(self):
+        """Return cached list of (id, module) for BaseTunerLayer modules that have scaling dicts."""
+        if hasattr(self, "_slider_tuner_layers_cache"):
+            return self._slider_tuner_layers_cache
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        result = []
+        for name, module in self.model.get_trained_component().named_modules():
+            if isinstance(module, BaseTunerLayer) and hasattr(module, "scaling"):
+                result.append((id(module), module))
+        self._slider_tuner_layers_cache = result
+        return result
+
     def _ensure_parameter_dtype(self, parameters, target_dtype: torch.dtype, optimizer_name: str | None = None):
         converted = 0
         for param_or_group in parameters:
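
The helper above caches adapter layers keyed by id() so the per-batch scaling pass does not re-walk the module tree every step. A minimal sketch of what the walk finds, assuming peft is installed; the toy model and adapter config here are illustrative, not SimpleTuner's:

import torch
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

base = torch.nn.Sequential(torch.nn.Linear(16, 16))
model = get_peft_model(base, LoraConfig(target_modules=["0"], r=4, lora_alpha=8))

for name, module in model.named_modules():
    if isinstance(module, BaseTunerLayer) and hasattr(module, "scaling"):
        # peft LoRA layers keep {adapter_name: lora_alpha / r}
        print(name, module.scaling)  # e.g. {'default': 2.0}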
@@ -5431,6 +5445,28 @@ def train(self):
                     else:
                         self.model.get_trained_component().enable_lora()
 
+                    # slider
+                    raw_strength = prepared_batch.get("slider_strength", 1.0)
+                    try:
+                        strength = float(raw_strength)
+                    except (TypeError, ValueError):
+                        strength = 1.0
+
+                    slider_original_scaling = None
+                    if self.config.model_type == "lora" and strength != 1.0:
+                        with torch.no_grad():
+                            if self.config.lora_type.lower() == "lycoris":
+                                self.accelerator._lycoris_wrapped_network.set_multiplier(strength)
+                            else:
+                                tuner_layers = self._get_slider_tuner_layers()
+                                slider_original_scaling = {}
+                                for layer_id, module in tuner_layers:
+                                    saved = {}
+                                    for key, val in module.scaling.items():
+                                        saved[key] = val
+                                        module.scaling[key] = val * strength
+                                    slider_original_scaling[layer_id] = (module, saved)
+
                     training_logger.debug("Predicting.")
                     model_pred = self.model_predict(
                         prepared_batch=prepared_batch,
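
Mutating module.scaling works because peft's LoRA forward adds scaling[adapter] * lora_B(lora_A(x)) on top of the frozen base output, so multiplying the entry scales the adapter's contribution linearly, and a negative strength flips its sign. A toy check under the same illustrative setup as above:

import torch
from peft import LoraConfig, get_peft_model

base = torch.nn.Sequential(torch.nn.Linear(8, 8))
config = LoraConfig(target_modules=["0"], r=2, lora_alpha=4, init_lora_weights=False)
model = get_peft_model(base, config)
layer = model.base_model.model[0]  # the wrapped lora.Linear

x = torch.randn(1, 8)
with torch.no_grad():
    base_out = layer.base_layer(x)
    delta = model(x) - base_out
    layer.scaling["default"] *= 2.0  # what the trainer does per batch
    scaled_delta = model(x) - base_out
print(torch.allclose(scaled_delta, 2 * delta, atol=1e-6))  # True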
@@ -5601,6 +5637,14 @@ def train(self):
                     ):
                         self.distiller.discriminator_step(prepared_batch=prepared_batch)
                         self.distiller.post_training_step(self.model, step)
+                    if self.config.model_type == "lora" and strength != 1:
+                        with torch.no_grad():
+                            if self.config.lora_type.lower() == "lycoris":
+                                self.accelerator._lycoris_wrapped_network.set_multiplier(1.0)
+                            elif slider_original_scaling is not None:
+                                for module, saved in slider_original_scaling.values():
+                                    for key, val in saved.items():
+                                        module.scaling[key] = val
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     wandb_logs = {}
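
The restore pass mirrors the scale-up: the LyCORIS branch simply resets the network multiplier to 1.0, while the PEFT branch writes back the exact saved values (strength != 1 and strength != 1.0 compare equal, so the two guards match). One hypothetical alternative, not in the commit, is a try/finally context manager so restoration survives an exception inside model_predict:

from contextlib import contextmanager
import torch

@contextmanager
def scaled_adapters(tuner_layers, strength: float):
    """Temporarily multiply every adapter's scaling entries by strength."""
    saved = []
    with torch.no_grad():
        for _, module in tuner_layers:
            original = dict(module.scaling)
            for key in module.scaling:
                module.scaling[key] = original[key] * strength
            saved.append((module, original))
    try:
        yield
    finally:
        with torch.no_grad():
            for module, original in saved:
                module.scaling.update(original)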
