From e3d9d10fe1c57dbd7a6b9e7ab3ca74f97e0d3207 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Tue, 27 Jan 2026 11:25:29 +0800 Subject: [PATCH 01/16] Rallback all the Z-Image-Turbo bs modfiy, just modify z-image.py --- .../diffusion_models/z_image/z_image.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 368ae9e7c..2a79fd105 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -333,10 +333,25 @@ def get_noise_prediction( timestep_model_input = (1000 - timestep) / 1000 + text_embeds = text_embeddings.text_embeds + if isinstance(text_embeds, torch.Tensor): + if len(text_embeds.shape) == 3: + # if it is a single batch tensor, unbind it into a list of tensors + text_embeds = list(text_embeds.unbind(dim=0)) + elif isinstance(text_embeds, list): + # check if items are rank 3 (batch, length, dim) + if len(text_embeds[0].shape) == 3: + # flatten the list of batches into a single list of tensors + new_text_embeds = [] + for t in text_embeds: + if t is not None: + new_text_embeds += list(t.unbind(dim=0)) + text_embeds = new_text_embeds + model_out_list = self.transformer( latent_model_input_list, timestep_model_input, - text_embeddings.text_embeds, + text_embeds, )[0] noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) @@ -355,6 +370,24 @@ def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: do_classifier_free_guidance=False, device=self.device_torch, ) + # encode_prompt returns list of rank-2 tensors [seq_len, dim] + # Pad to same length and stack into rank-3 tensor [batch, seq_len, dim] + # for compatibility with concat_prompt_embeds and predict_noise + if isinstance(prompt_embeds, list): + # Find max sequence length, TODO: Or just use 512? 
+ max_seq_len = max(t.shape[0] for t in prompt_embeds) + # Pad each tensor to max length + padded = [] + for t in prompt_embeds: + if t.shape[0] < max_seq_len: + pad = torch.zeros( + (max_seq_len - t.shape[0], t.shape[1]), + dtype=t.dtype, + device=t.device, + ) + t = torch.cat([t, pad], dim=0) + padded.append(t) + prompt_embeds = torch.stack(padded, dim=0) pe = PromptEmbeds([prompt_embeds, None]) return pe From 30a91855755ca02aba3ded2c86f69650626a8151 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Thu, 29 Jan 2026 01:17:02 +0800 Subject: [PATCH 02/16] Fixed generate_single_image for training Z-Image --- .../diffusion_models/z_image/z_image.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 2a79fd105..fda894a5a 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -306,9 +306,33 @@ def generate_single_image( sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) + + # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be + # List[torch.FloatTensor] where each element is [seq_len, dim]. + # The pipeline concatenates these lists for CFG (not element-wise add). 
+ cond_embeds = conditional_embeds.text_embeds + uncond_embeds = unconditional_embeds.text_embeds + + # Convert rank-3 tensors back to list of rank-2 tensors + def to_embed_list(embeds): + if embeds is None: + return [] + if isinstance(embeds, list): + return embeds + if len(embeds.shape) == 3: + # [batch, seq_len, dim] -> list of [seq_len, dim] + return list(embeds.unbind(dim=0)) + elif len(embeds.shape) == 2: + # Already [seq_len, dim], wrap in list + return [embeds] + return embeds + + cond_embeds_list = to_embed_list(cond_embeds) + uncond_embeds_list = to_embed_list(uncond_embeds) + img = pipeline( - prompt_embeds=conditional_embeds.text_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, + prompt_embeds=cond_embeds_list, + negative_prompt_embeds=uncond_embeds_list, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, From 0c403b597e35d44207565b9a2acc1b1da22bbd99 Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 17:59:05 +0300 Subject: [PATCH 03/16] feat(training): timestep sampling, schedulers, networks, loss UI - blank_prompt_probability, differential guidance metric, LR graph - Gaussian/content-style timestep sampling, warmup for cosine LR - BaseSDTrainProcess timestep mapping, fixed cycle training - rank_dropout/module_dropout, alpha for Peft LoRA, SNR flow matching - Timestep debug logging, config and UI updates --- .gitignore | 6 +- config/examples/WARMUP_SCHEDULER_GUIDE.md | 299 +++++++++++++++++ .../examples/train_lora_flux_with_warmup.yaml | 79 +++++ extensions_built_in/sd_trainer/SDTrainer.py | 89 +++-- jobs/process/BaseSDTrainProcess.py | 304 +++++++++++++++++- toolkit/config_modules.py | 21 +- toolkit/dataloader_mixins.py | 4 + toolkit/lora_special.py | 15 +- toolkit/network_mixins.py | 8 +- toolkit/samplers/custom_flowmatch_sampler.py | 55 ++++ toolkit/scheduler.py | 124 ++++++- toolkit/train_tools.py | 20 +- ui/src/app/jobs/new/SimpleJob.tsx | 77 +++++ 
ui/src/components/JobLossGraph.tsx | 54 +++- ui/src/docs.tsx | 72 +++++ ui/src/hooks/useJobLossLog.tsx | 47 ++- ui/src/types.ts | 5 + 17 files changed, 1188 insertions(+), 91 deletions(-) create mode 100644 config/examples/WARMUP_SCHEDULER_GUIDE.md create mode 100644 config/examples/train_lora_flux_with_warmup.yaml diff --git a/.gitignore b/.gitignore index 04c233ac4..2a618d3f2 100644 --- a/.gitignore +++ b/.gitignore @@ -181,4 +181,8 @@ cython_debug/ ._.DS_Store aitk_db.db /notes.md -/data \ No newline at end of file +/data + +# ide +.cursor +.continue diff --git a/config/examples/WARMUP_SCHEDULER_GUIDE.md b/config/examples/WARMUP_SCHEDULER_GUIDE.md new file mode 100644 index 000000000..3f9a195fe --- /dev/null +++ b/config/examples/WARMUP_SCHEDULER_GUIDE.md @@ -0,0 +1,299 @@ +# Learning Rate Scheduler Warmup Guide + +## Overview + +This guide explains how to use the warmup functionality for learning rate schedulers in the AI Toolkit. Warmup gradually increases the learning rate from near-zero to the target learning rate over a specified number of steps, which can help stabilize training in the early stages. + +## Supported Schedulers + +Warmup is supported for the following schedulers: +- `cosine` - Cosine annealing scheduler +- `cosine_with_restarts` - Cosine annealing with warm restarts (SGDR) + +## How It Works + +When you specify `warmup_steps > 0`, the scheduler automatically creates a composite scheduler using PyTorch's `SequentialLR`: + +1. **Warmup Phase** (steps 0 to `warmup_steps`): + - Uses `LinearLR` to gradually increase learning rate from ~0 to the target LR + - Learning rate increases linearly: `lr = target_lr * (current_step / warmup_steps)` + +2. 
**Main Phase** (steps `warmup_steps` to `total_steps`): + - Uses the specified scheduler (cosine or cosine_with_restarts) + - `total_iters` specifies the TOTAL number of training iterations (including warmup) + - `T_0`/`T_max` specify iterations for the MAIN scheduler phase (after warmup) + - If T_0/T_max is specified, it takes priority over the value calculated from total_iters + - Main scheduler iterations = total_iters - warmup_steps (or T_0/T_max if specified) + +### Learning Rate Progression + +``` +LR | + | _/--\ /\ /\ + | / \ / \ / \ + | / \ / \ / \ + | / \/ \ / \ + |/ \/ \ + +----------------------------> steps + |<-warmup->|<-- cosine restarts --> +``` + +## Parameter Semantics + +### Key concepts + +- **`total_iters`**: TOTAL training iterations (including warmup) +- **`T_0`/`T_max`**: Main scheduler iterations (after warmup), overrides calculation from total_iters + +### Example: Training with 1000 total steps + +**Config 1: Using total_iters (automatic calculation)** + +```yaml +train: + steps: 1000 # Will be used as total_iters by BaseSDTrainProcess + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 + # total_iters will be auto-set to 1000 by BaseSDTrainProcess + T_mult: 2 +``` + +**Result:** +- Steps 0-100: Linear warmup +- Steps 100-1000: Cosine with restarts (900 iterations = 1000 - 100) +- Total: 1000 steps ✓ + +**Config 2: Using T_0 explicitly (overrides calculation)** + +```yaml +train: + steps: 1000 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 + total_iters: 1000 + T_0: 500 # Overrides default calculation (1000 - 100) + T_mult: 2 +``` + +**Result:** +- Steps 0-100: Linear warmup +- Steps 100-600: First cosine cycle (500 iterations from T_0) +- The first cycle ends at step 600 (before total_iters!); the scheduler then restarts with a cycle of T_0 × T_mult = 1000 iterations while training continues to step 1000 + +**Priority:** If both `T_0` and `total_iters` are specified, `T_0` takes priority and determines main scheduler length. 
+ +## Configuration Examples + +### Example 1: Cosine with Restarts + Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 # Warmup for first 100 steps + T_mult: 2 # Double restart period each cycle + eta_min: 1e-7 # Minimum learning rate +``` + +### Example 2: Cosine + Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine" + lr_scheduler_params: + warmup_steps: 100 # Warmup for first 100 steps + eta_min: 1e-7 # Minimum learning rate +``` + +### Example 3: Without Warmup (Backward Compatible) + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + T_mult: 2 + eta_min: 1e-7 + # No warmup_steps specified = no warmup +``` + +### Example 4: Explicitly Disable Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 0 # Explicitly disable warmup + T_mult: 2 + eta_min: 1e-7 +``` + +## Parameters + +### Common Parameters + +- **`warmup_steps`** (optional, default: 0) + - Number of steps for the warmup phase + - Set to 0 or omit to disable warmup + - Typical values: 50-500 depending on total training steps + - Rule of thumb: 5-10% of total steps + +### Cosine Scheduler Parameters + +- **`eta_min`** (optional, default: 0) + - Minimum learning rate + +### Cosine with Restarts Parameters + +- **`T_mult`** (optional, default: 1) + - Factor to increase the restart period after each restart + - `T_mult=1`: equal restart periods + - `T_mult=2`: double the period each time + +- **`eta_min`** (optional, default: 0) + - Minimum learning rate + +## Choosing Warmup Steps + +### General Guidelines + +1. **Small datasets (< 1000 images)** + - Warmup steps: 50-100 + - Helps prevent overfitting to early batches + +2. **Medium datasets (1000-10000 images)** + - Warmup steps: 100-250 + - Balances stability and training time + +3. 
**Large datasets (> 10000 images)** + - Warmup steps: 250-500 + - More warmup helps with stability + +4. **Percentage-based approach** + - Use 5-10% of total training steps + - Example: 2000 steps → 100-200 warmup steps + +### When to Use Warmup + +✅ **Use warmup when:** +- Training from scratch or with random initialization +- Using high learning rates (> 1e-4) +- Experiencing unstable early training +- Training large models or with large batch sizes +- Using aggressive optimizers (Adam with high β values) + +❌ **Skip warmup when:** +- Fine-tuning from a well-trained checkpoint +- Using very low learning rates (< 1e-5) +- Training is already stable without warmup +- Total training steps are very small (< 500) + +## Implementation Details + +### Under the Hood + +The implementation uses PyTorch's built-in schedulers: + +```python +# Warmup phase +warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1e-10, # Start almost at 0 + end_factor=1.0, # End at full LR + total_iters=warmup_steps +) + +# Main phase (example for cosine_with_restarts) +main_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=total_iters - warmup_steps, # Calculated from total iterations + # Or explicitly specify main scheduler iterations: + # T_0=900, # Direct specification (ignores total_iters calculation) + T_mult=2, + eta_min=1e-7 +) + +# Combined scheduler +combined_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_steps] +) +``` + +### Backward Compatibility + +- All existing configurations continue to work without changes +- Warmup is only activated when `warmup_steps > 0` is explicitly specified +- Default behavior (no warmup) is preserved when `warmup_steps` is not specified + +## Testing + +A test script is provided to verify the warmup functionality: + +```bash +# Activate your virtual environment +.\venv\Scripts\activate.ps1 # Windows 
PowerShell +# or +source venv/bin/activate # Linux/Mac + +# Run tests +python test_scheduler_warmup.py +``` + +The test script verifies: +1. Backward compatibility (schedulers work without warmup) +2. Warmup functionality (SequentialLR is created when warmup_steps > 0) +3. Learning rate progression (LR increases during warmup, then follows main scheduler) + +## Full Configuration Example + +See `config/examples/train_lora_flux_with_warmup.yaml` for a complete working example. + +## Troubleshooting + +### Issue: Learning rate doesn't increase during warmup + +**Solution:** Make sure `warmup_steps` is specified in `lr_scheduler_params`, not at the top level of `train`. + +```yaml +# ❌ Wrong +train: + warmup_steps: 100 + lr_scheduler_params: + T_mult: 2 + +# ✅ Correct +train: + lr_scheduler_params: + warmup_steps: 100 + T_mult: 2 +``` + +### Issue: Training is unstable even with warmup + +**Possible solutions:** +1. Increase `warmup_steps` (try 10-15% of total steps) +2. Reduce base learning rate +3. Use gradient clipping (`max_grad_norm`) +4. Reduce batch size + +### Issue: Warmup takes too long + +**Solution:** Reduce `warmup_steps`. Remember, warmup is just the initial phase. If warmup is more than 10-15% of total training, it might be too long. 
+ +## References + +- [PyTorch CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) +- [PyTorch CosineAnnealingWarmRestarts](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html) +- [PyTorch SequentialLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.SequentialLR.html) +- [SGDR: Stochastic Gradient Descent with Warm Restarts](https://arxiv.org/abs/1608.03983) diff --git a/config/examples/train_lora_flux_with_warmup.yaml b/config/examples/train_lora_flux_with_warmup.yaml new file mode 100644 index 000000000..18a473141 --- /dev/null +++ b/config/examples/train_lora_flux_with_warmup.yaml @@ -0,0 +1,79 @@ +--- +# Example configuration demonstrating warmup support for cosine_with_restarts scheduler +# This is based on train_lora_flux_24gb.yaml with added warmup configuration + +job: extension +config: + name: "flux_lora_with_warmup_v1" + process: + - type: 'sd_trainer' + training_folder: "output" + device: cuda:0 + + network: + type: "lora" + linear: 16 + linear_alpha: 16 + + save: + dtype: float16 + save_every: 250 + max_step_saves_to_keep: 4 + push_to_hub: false + + datasets: + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 + shuffle_tokens: false + cache_latents_to_disk: true + resolution: [ 512, 768, 1024 ] + + train: + batch_size: 1 + steps: 2000 + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false + gradient_checkpointing: true + noise_scheduler: "flowmatch" + optimizer: "adamw8bit" + lr: 1e-4 + + # Learning rate scheduler with warmup + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 # Linear warmup for first 100 steps (0 → lr) + T_mult: 2 # Double the restart period each time + eta_min: 1e-7 # Minimum learning rate + + # Alternative: cosine scheduler with warmup + # lr_scheduler: "cosine" + # lr_scheduler_params: + # warmup_steps: 100 # 
Linear warmup for first 100 steps + # eta_min: 1e-7 # Minimum learning rate + + # For backward compatibility (no warmup): + # lr_scheduler: "cosine_with_restarts" + # lr_scheduler_params: + # T_mult: 2 + # eta_min: 1e-7 + + model: + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true + + sample: + sampler: "flowmatch" + sample_every: 250 + width: 1024 + height: 1024 + prompts: + - "a photo of a cat" + - "a photo of a dog" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 152cd131c..da6c82677 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -6,7 +6,6 @@ import numpy as np from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel -import torch.functional as F from safetensors.torch import load_file from torch.utils.data import DataLoader, ConcatDataset @@ -112,6 +111,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0]) else: raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}") + + # store differential guidance norm metric for logging + self.diff_guidance_norm = None def before_model_load(self): @@ -708,11 +710,14 @@ def calculate_loss( unconditional_target = unconditional_target * alpha target = unconditional_target + guidance_scale * (target - unconditional_target) - - if self.train_config.do_differential_guidance: - with torch.no_grad(): - guidance_scale = self.train_config.differential_guidance_scale - target = noise_pred + guidance_scale * (target - noise_pred) + + if self.train_config.do_differential_guidance: + with torch.no_grad(): + guidance_scale = self.train_config.differential_guidance_scale + diff_correction = guidance_scale * (target - noise_pred) + # Calculate L2 norm 
for monitoring + self.diff_guidance_norm = torch.norm(diff_correction).item() + target = noise_pred + diff_correction if target is None: target = noise @@ -789,6 +794,22 @@ def calculate_loss( elif len(loss.shape) == 5: timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() loss = loss * timestep_weight + elif self.train_config.content_or_style == 'fixed_cycle' and self.train_config.fixed_cycle_weight_peak_timesteps: + # fixed_cycle: weight loss by Gaussian peaks at fixed_cycle_weight_peak_timesteps (e.g. 500, 375), mean-normalized + peaks = self.train_config.fixed_cycle_weight_peak_timesteps + sigma = self.train_config.fixed_cycle_weight_sigma + peaks_t = torch.tensor(peaks, device=timesteps.device, dtype=timesteps.dtype) + diff = timesteps.unsqueeze(1).float() - peaks_t.unsqueeze(0) + weight_per_timestep = torch.exp(-(diff / sigma) ** 2).max(dim=1)[0] + cycle_ts = torch.tensor(self.train_config.fixed_cycle_timesteps, device=timesteps.device, dtype=timesteps.dtype) + diff_cycle = cycle_ts.unsqueeze(1).float() - peaks_t.unsqueeze(0) + mean_w = torch.exp(-(diff_cycle / sigma) ** 2).max(dim=1)[0].mean().clamp(min=1e-8) + timestep_weight = (weight_per_timestep / mean_w).to(loss.device, dtype=loss.dtype) + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() + loss = loss * timestep_weight if self.train_config.do_prior_divergence and prior_pred is not None: loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) @@ -1983,15 +2004,19 @@ def get_adapter_multiplier(): prior_pred=prior_to_calculate_loss, ) - if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: - # send the loss backwards otherwise checkpointing will fail - self.accelerator.backward(loss) - normal_loss = loss.detach() # dont send backward again - + # Determine if we should run BPP 
this step + should_run_bpp = False + if self.train_config.blank_prompt_preservation: + should_run_bpp = random.random() < self.train_config.blank_prompt_probability + + # Flag to track if backward was already performed in preservation block + backward_done = False + + if self.train_config.diff_output_preservation or should_run_bpp: with torch.no_grad(): if self.train_config.diff_output_preservation: preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) - elif self.train_config.blank_prompt_preservation: + elif should_run_bpp: blank_embeds = self.cached_blank_embeds.clone().detach().to( self.device_torch, dtype=dtype ) @@ -2008,12 +2033,14 @@ def get_adapter_multiplier(): ) multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier - self.accelerator.backward(preservation_loss) - - loss = normal_loss + preservation_loss - loss = loss.clone().detach() - # require grad again so the backward wont fail - loss.requires_grad_(True) + + # Combine main loss and preservation loss with loss_multiplier applied + total_loss = loss * loss_multiplier.mean() + preservation_loss + self.accelerator.backward(total_loss) + + # Detach total loss for logging and nan checks + loss = total_loss.detach() + backward_done = True # check if nan if torch.isnan(loss): @@ -2021,18 +2048,20 @@ def get_adapter_multiplier(): loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): - # todo we have multiplier seperated. works for now as res are not in same batch, but need to change - loss = loss * loss_multiplier.mean() - # IMPORTANT if gradient checkpointing do not leave with network when doing backward - # it will destroy the gradients. 
This is because the network is a context manager - # and will change the multipliers back to 0.0 when exiting. They will be - # 0.0 for the backward pass and the gradients will be 0.0 - # I spent weeks on fighting this. DON'T DO IT - # with fsdp_overlap_step_with_backward(): - # if self.is_bfloat: - # loss.backward() - # else: - self.accelerator.backward(loss) + # Only perform backward if not already done in preservation block + if not backward_done: + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + # with fsdp_overlap_step_with_backward(): + # if self.is_bfloat: + # loss.backward() + # else: + self.accelerator.backward(loss) return loss.detach() # flush() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 925d34daa..0b3e8e31e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -4,6 +4,7 @@ import json import random import shutil +import time from collections import OrderedDict import os import re @@ -98,6 +99,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.start_step = 0 self.epoch_num = 0 self.last_save_step = 0 + # For logging timestep debugging + self._collected_indices = [] + self._collected_timesteps = [] # start at 1 so we can do a sample at the start self.grad_accumulation_step = 1 # if true, then we do not do an optimizer step. 
We are accumulating gradients @@ -469,14 +473,29 @@ def clean_up_saves(self): for item in items_to_remove: print_acc(f"Removing old save: {item}") - if os.path.isdir(item): - shutil.rmtree(item) - else: - os.remove(item) - # see if a yaml file with same name exists - yaml_file = os.path.splitext(item)[0] + ".yaml" - if os.path.exists(yaml_file): - os.remove(yaml_file) + try: + if os.path.isdir(item): + shutil.rmtree(item) + else: + os.remove(item) + # see if a yaml file with same name exists + yaml_file = os.path.splitext(item)[0] + ".yaml" + if os.path.exists(yaml_file): + os.remove(yaml_file) + except PermissionError as e: + print_acc(f"Could not remove {item}: {e}. Skipping (file may be in use).") + try: + time.sleep(2) + if os.path.isdir(item): + shutil.rmtree(item) + else: + os.remove(item) + yaml_file = os.path.splitext(item)[0] + ".yaml" + if os.path.exists(yaml_file): + os.remove(yaml_file) + print_acc(f"Retry succeeded: removed {item}") + except PermissionError: + print_acc(f"Retry failed for {item}. 
Leaving file in place.") if combined_items: latest_item = combined_items[-1] return latest_item @@ -1236,22 +1255,121 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): # for style, it is best to favor later timesteps orig_timesteps = torch.rand((batch_size,), device=latents.device) + ntt = self.train_config.num_train_timesteps if content_or_style == 'content': - timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps + timestep_indices = (1 - orig_timesteps) ** self.train_config.timestep_bias_exponent * self.train_config.num_train_timesteps elif content_or_style == 'style': - timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps + timestep_indices = orig_timesteps ** self.train_config.timestep_bias_exponent * self.train_config.num_train_timesteps + + timestep_indices = value_map( + timestep_indices, + 0, + ntt, + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices = timestep_indices.long().clamp( + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices.sort() + + elif content_or_style == 'gaussian': + # Gaussian (normal) distribution with configurable mean and std + # gaussian_mean: controls the center of distribution (default 0.5) + # - Lower values (0.0-0.5): favor earlier timesteps (more noise) + # - Higher values (0.5-1.0): favor later timesteps (less noise) + # gaussian_std: controls the spread of distribution (default 0.2) + # - Affects the shape: smaller = narrower, larger = wider + + # Curriculum learning: linearly interpolate gaussian_std if target is set + if self.train_config.gaussian_std_target is not None: + progress = self.step_num / self.train_config.steps + current_std = self.train_config.gaussian_std + progress * (self.train_config.gaussian_std_target - self.train_config.gaussian_std) + else: + current_std = self.train_config.gaussian_std + + def truncated_normal_samples(batch_size, mu, sigma, device, low=0.0, high=1.0): + # Convert mu 
and sigma to tensors to ensure all operations are handled by PyTorch + # and to avoid TypeErrors with torch.special functions + mu = torch.as_tensor(mu, device=device, dtype=torch.float32) + sigma = torch.as_tensor(sigma, device=device, dtype=torch.float32) + + # Standardize the boundaries (calculate z-scores) + alpha = (low - mu) / sigma + beta = (high - mu) / sigma + + # Calculate Cumulative Distribution Function (CDF) values for the boundaries + cdf_low = torch.special.ndtr(alpha) + cdf_high = torch.special.ndtr(beta) + # Generate uniform random samples in the [0, 1] range + u = torch.rand((batch_size,), device=device) + + # Linearly interpolate between the boundary CDF values + v = cdf_low + u * (cdf_high - cdf_low) + + # Clamp values to avoid numerical instability (infinities) when calling ndtri + v = v.clamp(1e-7, 1 - 1e-7) + + # Apply the Inverse CDF (Quantile function) to transform samples to the truncated normal distribution + samples = mu + sigma * torch.special.ndtri(v) + + return samples + + gaussian_samples = truncated_normal_samples( + batch_size, + self.train_config.gaussian_mean, + current_std, + latents.device + ) + + # Scale to num_train_timesteps + timestep_indices = gaussian_samples * self.train_config.num_train_timesteps + + ntt = self.train_config.num_train_timesteps + + # Map to min/max_noise_steps range (same as content/style) timestep_indices = value_map( timestep_indices, 0, - self.train_config.num_train_timesteps - 1, - min_noise_steps, - max_noise_steps + ntt, + ntt - max_noise_steps, + ntt - min_noise_steps ) + timestep_indices = timestep_indices.long().clamp( - min_noise_steps, - max_noise_steps + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices.sort() + + elif content_or_style == 'fixed_cycle': + # Deterministic cycle over fixed timestep values (for Turbo LoRA reproducibility) + timestep_list = self.train_config.fixed_cycle_timesteps + if not timestep_list: + raise ValueError("content_or_style is 'fixed_cycle' 
but fixed_cycle_timesteps is empty") + resolved = getattr(self, '_fixed_cycle_resolved_timesteps', None) + if resolved is None: + list_copy = list(timestep_list) + if self.train_config.fixed_cycle_seed is not None: + random.Random(self.train_config.fixed_cycle_seed).shuffle(list_copy) + st = self.sd.noise_scheduler.timesteps + resolved = [] + for v in list_copy: + v_t = torch.tensor(v, device=st.device, dtype=st.dtype) + idx = (torch.abs(st - v_t)).argmin().item() + resolved.append(st[idx].item()) + self._fixed_cycle_resolved_timesteps = resolved + idx_cycle = self.step_num % len(resolved) + t_val = resolved[idx_cycle] + st = self.sd.noise_scheduler.timesteps + timesteps = torch.full( + (batch_size,), t_val, device=latents.device, dtype=st.dtype ) elif content_or_style == 'balanced': @@ -1275,8 +1393,71 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): else: raise ValueError(f"Unknown content_or_style {content_or_style}") with self.timer('convert_timestep_indices_to_timesteps'): - # convert the timestep_indices to a timestep - timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + # convert the timestep_indices to a timestep (fixed_cycle already set timesteps above) + if content_or_style != 'fixed_cycle': + timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + + # Debug logging for timestep distribution + if self.train_config.timestep_debug_log > 0: + # Always collect data (fixed_cycle has no timestep_indices, use cycle index) + if content_or_style == 'fixed_cycle': + self._collected_indices.append(self.step_num % len(self._fixed_cycle_resolved_timesteps)) + self._collected_timesteps.extend(timesteps.cpu().tolist()) + else: + self._collected_indices.extend(timestep_indices.cpu().tolist()) + self._collected_timesteps.extend(timesteps.cpu().tolist()) + + # Log when we have enough samples + if len(self._collected_indices) >= self.train_config.timestep_debug_log: + scheduler_timesteps = 
self.sd.noise_scheduler.timesteps.cpu().tolist() + + print_acc(f"\n{'='*70}") + print_acc(f"TIMESTEP DISTRIBUTION DEBUG") + print_acc(f"{'='*70}") + + print_acc(f"Total scheduler timesteps length: {len(scheduler_timesteps)}") + # print_acc(f"\nScheduler timesteps array (first 10): {scheduler_timesteps[:10]}") + # print_acc(f"Scheduler timesteps array (last 10): {scheduler_timesteps[-10:]}") + + print_acc(f"\nFirst 10 timestep_indices (generated indices):") + print_acc(f"{self._collected_indices[:10]}") + print_acc(f"\nFirst 10 timesteps (actual values after indexing):") + print_acc(f"{self._collected_timesteps[:10]}") + + print_acc(f"Config:") + print_acc(f" content_or_style: {content_or_style}") + print_acc(f" noise_scheduler: {self.train_config.noise_scheduler}") + print_acc(f" timestep_type: {self.train_config.timestep_type}") + print_acc(f" num_train_timesteps: {self.train_config.num_train_timesteps}") + print_acc(f" min_denoising_steps: {min_noise_steps}") + print_acc(f" max_denoising_steps: {max_noise_steps}") + print_acc(f" gaussian_mean: {self.train_config.gaussian_mean}") + print_acc(f" gaussian_std: {self.train_config.gaussian_std}") + print_acc(f" gaussian_std_target: {self.train_config.gaussian_std_target}") + + # Statistics + num_samples = self.train_config.timestep_debug_log + indices_min = min(self._collected_indices[:num_samples]) + indices_max = max(self._collected_indices[:num_samples]) + indices_mean = sum(self._collected_indices[:num_samples]) / num_samples + + timesteps_min = min(self._collected_timesteps[:num_samples]) + timesteps_max = max(self._collected_timesteps[:num_samples]) + timesteps_mean = sum(self._collected_timesteps[:num_samples]) / num_samples + + print_acc(f"\nStatistics ({num_samples} samples):") + print_acc(f" Indices: max={indices_max}, mean={indices_mean:.1f}, min={indices_min}") + print_acc(f" Timesteps: max={timesteps_max:.1f}, mean={timesteps_mean:.1f}, min={timesteps_min:.1f}") + if self.train_config.gaussian_std_target 
is not None: + progress = self.step_num / self.train_config.steps + current_std = self.train_config.gaussian_std + progress * (self.train_config.gaussian_std_target - self.train_config.gaussian_std) + percent = progress * 100.0 + print_acc(f" current_std: {current_std:.3f} (progress: {percent:.1f}%)") + print_acc(f" Step: {self.step_num} ({self.step_num * 100 / self.train_config.steps:.1f}%)") + print_acc(f"{'='*70}\n") + + self._collected_indices = [] + self._collected_timesteps = [] with self.timer('prepare_noise'): # get noise @@ -1758,6 +1939,8 @@ def run(self): is_ssd=self.model_config.is_ssd, is_vega=self.model_config.is_vega, dropout=self.network_config.dropout, + rank_dropout=self.network_config.rank_dropout, + module_dropout=self.network_config.module_dropout, use_text_encoder_1=self.model_config.use_text_encoder_1, use_text_encoder_2=self.model_config.use_text_encoder_2, use_bias=is_lorm, @@ -1806,6 +1989,61 @@ def run(self): self.network.prepare_grad_etc(text_encoder, unet) flush() + # Create sampling network if sampling_transformer is specified and LoRA is used + if hasattr(self.sd, '_sampling_transformer') and self.sd._sampling_transformer is not None: + print_acc("Creating sampling network for _sampling_transformer") + sampling_network = NetworkClass( + text_encoder=text_encoder, + unet=self.sd._sampling_transformer, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + 
is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + rank_dropout=self.network_config.rank_dropout, + module_dropout=self.network_config.module_dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, + base_model=self.sd, + **network_kwargs + ) + + sampling_network.force_to(self.device_torch, dtype=torch.float32) + sampling_network._update_torch_multiplier() + + sampling_network.apply_to( + text_encoder, + self.sd._sampling_transformer, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + # Set can_merge_in same as main network (False if quantized/layer_offloading) + if self.model_config.quantize or self.model_config.layer_offloading: + sampling_network.can_merge_in = False + + # Store sampling network on model for use during generation + self.sd._sampling_network = sampling_network + flush() + # LyCORIS doesnt have default_lr config = { 'text_encoder_lr': self.train_config.lr, @@ -2226,6 +2464,7 @@ def run(self): # torch.cuda.empty_cache() # if optimizer has get_lrs method, then use it learning_rate = 0.0 + update_rms = 0.0 # Average weight update RMS (for monitoring optimizer step magnitude) if not did_oom and loss_dict is not None: if hasattr(optimizer, 'get_avg_learning_rate'): learning_rate = optimizer.get_avg_learning_rate() @@ -2239,8 +2478,14 @@ def run(self): ) else: learning_rate = optimizer.param_groups[0]['lr'] + + # Get average weight update RMS if optimizer supports it (e.g., Adafactor) + if hasattr(optimizer, 'get_avg_update_rms'): + update_rms = optimizer.get_avg_update_rms() prog_bar_string = f"lr: {learning_rate:.1e}" + if update_rms > 0: + prog_bar_string += f" upd: {update_rms:.2e}" for key, value in 
loss_dict.items(): prog_bar_string += f" {key}: {value:.3e}" @@ -2300,6 +2545,9 @@ def run(self): for key, value in loss_dict.items(): self.writer.add_scalar(f"{key}", value, self.step_num) self.writer.add_scalar(f"lr", learning_rate, self.step_num) + # Log weight update RMS if available (shows optimizer step magnitude) + if update_rms > 0: + self.writer.add_scalar("train/update_rms", update_rms, self.step_num) if self.progress_bar is not None: self.progress_bar.unpause() @@ -2308,6 +2556,17 @@ def run(self): self.logger.log({ 'learning_rate': learning_rate, }) + # Log differential guidance norm if available + if hasattr(self, 'diff_guidance_norm') and self.diff_guidance_norm is not None: + self.logger.log({ + 'diff_guidance_norm': self.diff_guidance_norm, + }) + self.diff_guidance_norm = None + # Log weight update RMS if available (Adafactor optimizer statistic) + if update_rms > 0: + self.logger.log({ + 'train/update_rms': update_rms, + }) if loss_dict is not None: for key, value in loss_dict.items(): self.logger.log({ @@ -2319,6 +2578,17 @@ def run(self): self.logger.log({ 'learning_rate': learning_rate, }) + # Log differential guidance norm if available + if hasattr(self, 'diff_guidance_norm') and self.diff_guidance_norm is not None: + self.logger.log({ + 'diff_guidance_norm': self.diff_guidance_norm, + }) + self.diff_guidance_norm = None + # Log weight update RMS if available (Adafactor optimizer statistic) + if update_rms > 0: + self.logger.log({ + 'train/update_rms': update_rms, + }) for key, value in loss_dict.items(): self.logger.log({ f'loss/{key}': value, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a48cbe09b..ae3447156 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -182,6 +182,8 @@ def __init__(self, **kwargs): self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) self.dropout: Union[float, None] = kwargs.get('dropout', None) + 
self.rank_dropout: Union[float, None] = kwargs.get('rank_dropout', None) + self.module_dropout: Union[float, None] = kwargs.get('module_dropout', None) self.network_kwargs: dict = kwargs.get('network_kwargs', {}) self.lorm_config: Union[LoRMConfig, None] = None @@ -344,7 +346,7 @@ def __init__(self, **kwargs): self.num_tokens: str = kwargs.get('num_tokens', 4) -ContentOrStyleType = Literal['balanced', 'style', 'content'] +ContentOrStyleType = Literal['balanced', 'style', 'content', 'gaussian', 'fixed_cycle'] LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] @@ -353,6 +355,18 @@ def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.gaussian_mean: float = kwargs.get('gaussian_mean', 0.5) + self.gaussian_std: float = kwargs.get('gaussian_std', 0.2) + self.gaussian_std_target: float = kwargs.get('gaussian_std_target', None) + self.timestep_bias_exponent: float = kwargs.get('timestep_bias_exponent', 3.0) + self.timestep_debug_log: int = kwargs.get('timestep_debug_log', 0) + # fixed_cycle: deterministic cycle over fixed timestep values (for Turbo LoRA reproducibility) + _default_fixed_cycle = [999, 875, 750, 625, 500, 375, 250, 125] + _fc = kwargs.get('fixed_cycle_timesteps', _default_fixed_cycle) + self.fixed_cycle_timesteps: Optional[List[float]] = _fc if (_fc is not None and len(_fc) > 0) else _default_fixed_cycle + self.fixed_cycle_seed: Optional[int] = kwargs.get('fixed_cycle_seed', None) + self.fixed_cycle_weight_peak_timesteps: Optional[List[float]] = kwargs.get('fixed_cycle_weight_peak_timesteps', [500, 375]) + self.fixed_cycle_weight_sigma: float = kwargs.get('fixed_cycle_weight_sigma', 372.8) self.steps: int = kwargs.get('steps', 1000) self.lr = kwargs.get('lr', 1e-6) self.unet_lr = kwargs.get('unet_lr', self.lr) @@ 
-464,6 +478,7 @@ def __init__(self, **kwargs): # blank prompt preservation will preserve the model's knowledge of a blank prompt self.blank_prompt_preservation = kwargs.get('blank_prompt_preservation', False) self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0) + self.blank_prompt_probability = kwargs.get('blank_prompt_probability', 1.0) # legacy if match_adapter_assist and self.match_adapter_chance == 0.0: @@ -667,6 +682,10 @@ def __init__(self, **kwargs): # 20 different model variants self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path) + # for models that support it (e.g., zimage), a separate model path for sampling/inference + # training uses name_or_path, sampling uses sampling_name_or_path if set + self.sampling_name_or_path: Optional[str] = kwargs.get("sampling_name_or_path", None) + # path to an accuracy recovery adapter, either local or remote self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 20140eb8b..39dbc6468 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -578,6 +578,8 @@ def load_and_process_video( img = img.transpose(Image.FLIP_TOP_BOTTOM) # Apply bucketing + if self.scale_to_width <= 0 or self.scale_to_height <= 0: + raise ValueError(f"Invalid scale dimensions for video {self.path}: scale_to_width={self.scale_to_width}, scale_to_height={self.scale_to_height}") img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) img = img.crop(( self.crop_x, @@ -780,6 +782,8 @@ def load_and_process_image( if self.dataset_config.buckets: # scale and crop based on file item + if self.scale_to_width <= 0 or self.scale_to_height <= 0: + raise ValueError(f"Invalid scale dimensions for image {self.path}: scale_to_width={self.scale_to_width}, scale_to_height={self.scale_to_height}") img = img.resize((self.scale_to_width, 
self.scale_to_height), Image.BICUBIC) # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 5cb19229a..12cec5c1b 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -279,12 +279,15 @@ def __init__( if self.network_type.lower() != "lokr" or not self.use_old_lokr_format: self.peft_format = True - if self.peft_format: - # no alpha for peft - self.alpha = self.lora_dim - alpha = self.alpha - self.conv_alpha = self.conv_lora_dim - conv_alpha = self.conv_alpha + # Allow alpha for peft + # Todo: Folding Alpha on save + # + # if self.peft_format: + # # no alpha for peft + # self.alpha = self.lora_dim + # alpha = self.alpha + # self.conv_alpha = self.conv_lora_dim + # conv_alpha = self.conv_alpha self.full_train_in_out = full_train_in_out diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index b8556f125..493878bd3 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -547,8 +547,8 @@ def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16): new_save_dict = {} for key, value in save_dict.items(): # lokr needs alpha - if key.endswith('.alpha') and self.network_type.lower() != "lokr": - continue + # if key.endswith('.alpha') and self.network_type.lower() != "lokr": + # continue new_key = key new_key = new_key.replace('lora_down', 'lora_A') new_key = new_key.replace('lora_up', 'lora_B') @@ -633,8 +633,8 @@ def load_weights(self: Network, file, force_weight_mapping=False): # lora_down = lora_A # lora_up = lora_B # no alpha - if load_key.endswith('.alpha') and self.network_type.lower() != "lokr": - continue + # if load_key.endswith('.alpha') and self.network_type.lower() != "lokr": + # continue load_key = load_key.replace('lora_A', 'lora_down') load_key = load_key.replace('lora_B', 'lora_up') # replace all . 
with $$ diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index bac7f3fde..cc1e71033 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_noise_sigma = 1.0 self.timestep_type = "linear" + self._alphas_cumprod = None # Lazy initialization with torch.no_grad(): # create weights for timesteps @@ -56,6 +57,60 @@ def __init__(self, *args, **kwargs): self.linear_timesteps_weights2 = hbsmntw_weighing pass + def _compute_alphas_cumprod(self): + """ + Compute an alphas_cumprod-style table for flow matching for SNR weighting compatibility. + + For flow matching: x_t = (1-t)*x_0 + t*noise + Flow-matching SNR: (1-t)^2 / t^2 + + For DDPM: SNR = alphas_cumprod / (1 - alphas_cumprod) + This method uses: alphas_cumprod = (1-t)^2 + NOTE(review): this is not the exact inverse of the SNR above. With + alphas_cumprod = (1-t)^2 the DDPM formula gives (1-t)^2 / (t*(2-t)); + the exact equivalent would be (1-t)^2 / ((1-t)^2 + t^2). Confirm intent. + """ + num_timesteps = 1000 + # Create timesteps from 1000 to 0 (descending) + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + + # Normalize to [0, 1] range + t = timesteps / 1000.0 + + # Clamp to avoid numerical issues at boundaries + t = torch.clamp(t, min=1e-8, max=1.0 - 1e-8) + + # Compute alphas_cumprod as (1-t)^2 + # NOTE(review): alphas_cumprod / (1 - alphas_cumprod) is then (1-t)^2 / (t*(2-t)), + # not the (1-t)^2 / t^2 that compute_snr() returns -- confirm which is intended + alphas_cumprod = (1.0 - t) ** 2 + + return alphas_cumprod + + def compute_snr(self): + """ + Compute SNR for each timestep in flow matching. + + For flow matching: x_t = (1-t)*x_0 + t*noise + Signal = (1-t)*x_0, Noise = t*noise + SNR = (1-t)^2 / t^2 + """ + num_timesteps = 1000 + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + t = timesteps / 1000.0 + + # Add small epsilon to avoid division by zero + epsilon = 1e-8 + snr = ((1.0 - t) ** 2) / (t ** 2 + epsilon) + + return snr + + @property + def alphas_cumprod(self): + """ + Returns equivalent alphas_cumprod for flow matching. 
+ This provides compatibility with SNR weighting functions in train_tools.py. + """ + if self._alphas_cumprod is None: + self._alphas_cumprod = self._compute_alphas_cumprod() + return self._alphas_cumprod + def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor: # Get the indices of the timesteps step_indices = [(self.timesteps == t).nonzero().item() diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py index f6f8f61ae..efd4cf5e0 100644 --- a/toolkit/scheduler.py +++ b/toolkit/scheduler.py @@ -3,22 +3,130 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup +class SequentialLRWrapper(torch.optim.lr_scheduler.SequentialLR): + """ + Wrapper for SequentialLR that ignores extra arguments to step(). + + This is needed because the training code calls lr_scheduler.step(step_num), + but SequentialLR.step() doesn't accept any arguments. + """ + def step(self, *args, **kwargs): + # Ignore all arguments and call parent step() without them + super().step() + + +def _create_scheduler_with_warmup( + scheduler_type: str, + optimizer: torch.optim.Optimizer, + **scheduler_kwargs +): + """ + Creates a scheduler with optional warmup period using SequentialLR. + + Args: + scheduler_type: 'cosine' or 'cosine_with_restarts' + optimizer: The optimizer to schedule + scheduler_kwargs: Parameters for the scheduler. Can include: + - warmup_steps: Number of warmup steps (default: 0, no warmup) + - total_iters: TOTAL number of iterations INCLUDING warmup (default: 1000) + - T_0/T_max: Iterations for MAIN scheduler, overrides calculation from total_iters + - Other scheduler parameters (T_mult, eta_min, etc.) 
+ + Semantics: + - total_iters: Total training iterations (warmup + main scheduler) + - T_0/T_max: Main scheduler iterations (if specified, total_iters is ignored) + - If only total_iters specified: main_iters = total_iters - warmup_steps + - If T_0/T_max specified: main_iters = T_0/T_max (total_iters ignored) + + Returns: + A scheduler (SequentialLR if warmup_steps > 0, otherwise base scheduler) + """ + # Extract warmup_steps (default: 0, no warmup) + warmup_steps = scheduler_kwargs.pop('warmup_steps', 0) + + # Extract total_iters (GENERAL total, including warmup) + total_iters = scheduler_kwargs.pop('total_iters', 1000) + + # Calculate main scheduler iterations + # T_0/T_max have priority and specify main scheduler iterations directly + if scheduler_type == "cosine": + if 'T_max' in scheduler_kwargs: + # T_max specifies main scheduler iterations (ignores total_iters) + main_total_iters = scheduler_kwargs.pop('T_max') + else: + # Calculate from total_iters + main_total_iters = total_iters - warmup_steps + elif scheduler_type == "cosine_with_restarts": + if 'T_0' in scheduler_kwargs: + # T_0 specifies main scheduler iterations (ignores total_iters) + main_total_iters = scheduler_kwargs.pop('T_0') + else: + # Calculate from total_iters + main_total_iters = total_iters - warmup_steps + + # Validation: raise on a non-positive main-scheduler length; only warn when warmup covers total_iters + if main_total_iters <= 0: + raise ValueError( + f"Main scheduler iterations must be positive, got {main_total_iters}. " + f"Check your total_iters ({total_iters}) and warmup_steps ({warmup_steps})." + ) + if warmup_steps > 0 and warmup_steps >= total_iters: + print(f"WARNING: warmup_steps ({warmup_steps}) >= total_iters ({total_iters}). 
" + f"The main scheduler will have very few or no iterations.") + + if warmup_steps <= 0: + # No warmup, create base scheduler directly + if scheduler_type == "cosine": + return torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=main_total_iters, **scheduler_kwargs + ) + elif scheduler_type == "cosine_with_restarts": + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=main_total_iters, **scheduler_kwargs + ) + + # Create warmup scheduler (linear ramp from 0.1x to 1.0x of the target LR) + warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.1, # 1/10 of the target LR + end_factor=1.0, # End at full LR + total_iters=warmup_steps + ) + + # Create main scheduler + if scheduler_type == "cosine": + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=main_total_iters, **scheduler_kwargs + ) + elif scheduler_type == "cosine_with_restarts": + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=main_total_iters, **scheduler_kwargs + ) + + # Combine schedulers using SequentialLRWrapper + combined_scheduler = SequentialLRWrapper( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_steps] + ) + + return combined_scheduler + + def get_lr_scheduler( name: Optional[str], optimizer: torch.optim.Optimizer, **kwargs, ): if name == "cosine": - if 'total_iters' in kwargs: - kwargs['T_max'] = kwargs.pop('total_iters') - return torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, **kwargs + # All parameters passed via kwargs, handled in _create_scheduler_with_warmup + return _create_scheduler_with_warmup( + "cosine", optimizer, **kwargs ) elif name == "cosine_with_restarts": - if 'total_iters' in kwargs: - kwargs['T_0'] = kwargs.pop('total_iters') - return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - optimizer, **kwargs + # All parameters passed via kwargs, handled in _create_scheduler_with_warmup + return 
_create_scheduler_with_warmup( + "cosine_with_restarts", optimizer, **kwargs ) elif name == "step": diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 78e2183c3..f75c41961 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -624,7 +624,16 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device): try: if hasattr(noise_scheduler, "all_snr"): return - # compute it + + # Handle flow matching schedulers that have compute_snr method + if hasattr(noise_scheduler, "compute_snr") and callable(getattr(noise_scheduler, "compute_snr")): + with torch.no_grad(): + all_snr = noise_scheduler.compute_snr() + all_snr.requires_grad = False + noise_scheduler.all_snr = all_snr.to(device) + return + + # Standard DDPM/DDIM scheduler path with torch.no_grad(): alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) @@ -642,6 +651,15 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device): def get_all_snr(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return noise_scheduler.all_snr.to(device) + + # Handle flow matching schedulers that have compute_snr method + if hasattr(noise_scheduler, "compute_snr") and callable(getattr(noise_scheduler, "compute_snr")): + with torch.no_grad(): + all_snr = noise_scheduler.compute_snr() + all_snr.requires_grad = False + return all_snr.to(device) + + # Standard DDPM/DDIM scheduler path # compute it with torch.no_grad(): alphas_cumprod = noise_scheduler.alphas_cumprod diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 5db650c30..fd20e39aa 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -531,12 +531,15 @@ export default function SimpleJob({ setJobConfig(value, 'config.process[0].train.content_or_style')} options={[ { value: 'balanced', label: 'Balanced' }, { value: 'content', label: 'High Noise' }, { value: 'style', label: 'Low Noise' }, + { value: 'gaussian', label: 'Gaussian 
(Normal)' }, + { value: 'fixed_cycle', label: 'Fixed Cycle' }, ]} /> + {jobConfig.config.process[0].train.content_or_style === 'fixed_cycle' && ( + <> + { + const arr = value + .split(',') + .map(s => parseFloat(s.trim())) + .filter(n => !isNaN(n)); + setJobConfig(arr, 'config.process[0].train.fixed_cycle_timesteps'); + }} + placeholder="eg. 999, 875, 750, 625, 500, 375, 250, 125" + /> + setJobConfig(value ? parseInt(value as string) : null, 'config.process[0].train.fixed_cycle_seed')} + placeholder="eg. 42 (leave empty for no shuffle)" + min={0} + /> + { + const arr = value + .split(',') + .map(s => parseFloat(s.trim())) + .filter(n => !isNaN(n)); + setJobConfig(arr.length > 0 ? arr : null, 'config.process[0].train.fixed_cycle_weight_peak_timesteps'); + }} + placeholder="eg. 500, 375 (leave empty to disable)" + /> + setJobConfig(value, 'config.process[0].train.fixed_cycle_weight_sigma')} + placeholder="eg. 372.8" + min={0} + step={0.1} + /> + + )}
@@ -675,6 +737,21 @@ export default function SimpleJob({ placeholder="eg. 1.0" min={0} /> + + setJobConfig(value, 'config.process[0].train.blank_prompt_probability') + } + placeholder="eg. 0.1 for 10%, 1.0 for 100%" + min={0} + max={1} + step={0.1} + /> )} diff --git a/ui/src/components/JobLossGraph.tsx b/ui/src/components/JobLossGraph.tsx index 78a77bfb5..1a736a0db 100644 --- a/ui/src/components/JobLossGraph.tsx +++ b/ui/src/components/JobLossGraph.tsx @@ -1,7 +1,7 @@ 'use client'; import { Job } from '@prisma/client'; -import useJobLossLog, { LossPoint } from '@/hooks/useJobLossLog'; +import useJobLossLog, { LossPoint, MetricFilter } from '@/hooks/useJobLossLog'; import { useMemo, useState, useEffect } from 'react'; import { ResponsiveContainer, LineChart, Line, XAxis, YAxis, Tooltip, CartesianGrid, Legend } from 'recharts'; @@ -65,7 +65,8 @@ function strokeForKey(key: string) { } export default function JobLossGraph({ job }: Props) { - const { series, lossKeys, status, refreshLoss } = useJobLossLog(job.id, 2000); + const [metricFilter, setMetricFilter] = useState('loss'); + const { series, filteredKeys, status, refreshLoss } = useJobLossLog(job.id, 2000, metricFilter); // Controls const [useLogScale, setUseLogScale] = useState(false); @@ -91,18 +92,18 @@ export default function JobLossGraph({ job }: Props) { useEffect(() => { setEnabled(prev => { const next = { ...prev }; - for (const k of lossKeys) { + for (const k of filteredKeys) { if (next[k] === undefined) next[k] = true; } // drop removed keys for (const k of Object.keys(next)) { - if (!lossKeys.includes(k)) delete next[k]; + if (!filteredKeys.includes(k)) delete next[k]; } return next; }); - }, [lossKeys]); + }, [filteredKeys]); - const activeKeys = useMemo(() => lossKeys.filter(k => enabled[k] !== false), [lossKeys, enabled]); + const activeKeys = useMemo(() => filteredKeys.filter(k => enabled[k] !== false), [filteredKeys, enabled]); const perSeries = useMemo(() => { // Build per-series processed point 
arrays (raw + smoothed), then merge by step for charting. @@ -318,7 +319,7 @@ export default function JobLossGraph({ job }: Props) { {/* Controls */}
-
+
@@ -329,13 +330,44 @@ export default function JobLossGraph({ job }: Props) {
+
+ +
+ setMetricFilter('loss')} + label="Loss" + /> + setMetricFilter('learning_rate')} + label="Learning Rate" + /> + setMetricFilter('diff_guidance')} + label="Diff Guidance" + /> + setMetricFilter('all')} + label="All" + /> + setMetricFilter('other')} + label="Other" + /> +
+
+
- {lossKeys.length === 0 ? ( -
No loss keys found yet.
+ {filteredKeys.length === 0 ? ( +
No keys found yet.
) : (
- {lossKeys.map(k => ( + {filteredKeys.map(k => (
-
+
{windowSize === 0 ? 'all' : windowSize.toLocaleString()} diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index c82d800bd..6da72c11c 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -287,6 +287,78 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'train.blank_prompt_probability': { + title: 'BPP Probability', + description: ( + <> + Controls how often the Blank Prompt Preservation check runs during training. + Value between 0.0 and 1.0. Default is 1.0 (runs every step). + Setting to 0.1 means BPP runs ~10% of steps, reducing training time by up to 45% + while still preventing model degradation. Lower values give the model more freedom + to adapt to new concepts between BPP corrections. Recommended: 0.1-0.2 for Turbo models. + + ), + }, + 'train.content_or_style': { + title: 'Timestep Bias', + description: ( + <> + Controls how timesteps are sampled during training: +

+ Balanced: Uniform distribution across all timesteps. +

+ High Noise: Cubic distribution favoring earlier timesteps (more noise). Best for learning content/structure. +

+ Low Noise: Cubic distribution favoring later timesteps (less noise). Best for learning style/details. +

+ Gaussian (Normal): Normal distribution with configurable center and spread. Use gaussian_mean and gaussian_std in YAML config: +
+ • gaussian_mean (default 0.5): Center of distribution. Lower values (0.0-0.5) = more noise/earlier timesteps, higher values (0.5-1.0) = less noise/later timesteps. +
+ • gaussian_std (default 0.2): Spread of distribution. Smaller = narrower focus, larger = wider coverage. +
+ • gaussian_std_target (optional, default None): Enable curriculum learning. When set, gaussian_std will linearly interpolate from initial value to this target value during training. Example: start with gaussian_std: 0.001 (narrow distribution, focused training) → end with gaussian_std_target: 0.3 (wide distribution, diverse timestep coverage). +
+ • timestep_bias_exponent (default 3.0): Controls the cubic bias exponent for content/style timestep distribution. Higher values create stronger bias toward edges (early timesteps for content, late timesteps for style). Lower values create more uniform distribution. +

+ Fixed Cycle: Deterministic cycle over a fixed list of timestep values. Same step number always gets the same timestep, so training is reproducible. Recommended for distilled/Turbo models (e.g. Z-Image-Turbo LoRA) that are sensitive to random timestep sampling. Configure via YAML: fixed_cycle_timesteps, fixed_cycle_seed, fixed_cycle_weight_peak_timesteps. + + ), + }, + 'train.fixed_cycle_timesteps': { + title: 'Fixed Cycle Timesteps', + description: ( + <> + Used when content_or_style is fixed_cycle. List of timestep values (scheduler scale, typically 0–1000) to cycle through deterministically. At step s, the whole batch uses fixed_cycle_timesteps[s % length] (after optional shuffle by fixed_cycle_seed). Values are snapped to the nearest scheduler timestep. +

+ Default: [999, 875, 750, 625, 500, 375, 250, 125]. Set in config file (e.g. YAML) as an array of numbers. + + ), + }, + 'train.fixed_cycle_seed': { + title: 'Fixed Cycle Seed', + description: ( + <> + Used when content_or_style is fixed_cycle. If set, the list fixed_cycle_timesteps is shuffled once at the start of training using this seed. Same seed gives the same cycle order and reproducible runs. If null or unset, the cycle uses the order from the config. + + ), + }, + 'train.fixed_cycle_weight_peak_timesteps': { + title: 'Fixed Cycle Weight Peak Timesteps', + description: ( + <> + Used when content_or_style is fixed_cycle. Enables timestep-weighted loss (like timestep_type: weighted): loss is multiplied by weights with peaks at these timestep values. Default: [500, 375] so maximum weight is around 500 and 375. Set to null or empty to disable weighting. Set in config file as an array of numbers. + + ), + }, + 'train.fixed_cycle_weight_sigma': { + title: 'Fixed Cycle Weight Sigma', + description: ( + <> + Used when content_or_style is fixed_cycle and fixed_cycle_weight_peak_timesteps is set. Standard deviation (sigma) for the Gaussian distribution used to weight loss around the peak timesteps. Controls the width of the Gaussian peaks — larger values create wider, flatter distributions, smaller values create sharper peaks. Default: 372.8. 
+ + ), + }, 'train.do_differential_guidance': { title: 'Differential Guidance', description: ( diff --git a/ui/src/hooks/useJobLossLog.tsx b/ui/src/hooks/useJobLossLog.tsx index a0bf27ed3..5c4d062cb 100644 --- a/ui/src/hooks/useJobLossLog.tsx +++ b/ui/src/hooks/useJobLossLog.tsx @@ -11,13 +11,33 @@ export interface LossPoint { type SeriesMap = Record; +export type MetricFilter = 'loss' | 'learning_rate' | 'diff_guidance' | 'all' | 'other'; + +function categorizeMetric(key: string): 'loss' | 'learning_rate' | 'diff_guidance' | 'other' { + if (key === 'learning_rate') return 'learning_rate'; + if (key === 'diff_guidance_norm') return 'diff_guidance'; + if (/loss/i.test(key)) return 'loss'; + return 'other'; +} + +function matchesFilter(key: string, filter: MetricFilter): boolean { + if (filter === 'all') return true; + const category = categorizeMetric(key); + if (filter === 'other') return category === 'other'; + return category === filter; +} + function isLossKey(key: string) { // treat anything containing "loss" as a loss-series // (covers loss, train_loss, val_loss, loss/xyz, etc.) return /loss/i.test(key); } -export default function useJobLossLog(jobID: string, reloadInterval: null | number = null) { +export default function useJobLossLog( + jobID: string, + reloadInterval: null | number = null, + metricFilter: MetricFilter = 'loss' +) { const [series, setSeries] = useState({}); const [keys, setKeys] = useState([]); const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle'); @@ -28,12 +48,12 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb // track last step per key so polling is incremental per series const lastStepByKeyRef = useRef>({}); - const lossKeys = useMemo(() => { - const base = (keys ?? []).filter(isLossKey); + const filteredKeys = useMemo(() => { + const filtered = (keys ?? 
[]).filter(k => matchesFilter(k, metricFilter)); // if keys table is empty early on, fall back to just "loss" - if (base.length === 0) return ['loss']; - return base.sort(); - }, [keys]); + if (filtered.length === 0 && metricFilter === 'loss') return ['loss']; + return filtered.sort(); + }, [keys, metricFilter]); const refreshLoss = useCallback(async () => { if (!jobID) return; @@ -54,10 +74,13 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb const newKeys = first.keys ?? []; setKeys(newKeys); - const wantedLossKeys = (newKeys.filter(isLossKey).length ? newKeys.filter(isLossKey) : ['loss']).sort(); + const wantedKeys = (newKeys.filter(k => matchesFilter(k, metricFilter)).length + ? newKeys.filter(k => matchesFilter(k, metricFilter)) + : (metricFilter === 'loss' ? ['loss'] : []) + ).sort(); // Step 2: fetch each loss key incrementally (since_step per key if polling) - const requests = wantedLossKeys.map(k => { + const requests = wantedKeys.map(k => { const params: Record = { key: k }; if (reloadInterval && lastStepByKeyRef.current[k] != null) { @@ -101,9 +124,9 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb : (lastStepByKeyRef.current[k] ?? 
null); } - // remove stale loss keys that no longer exist (rare, but keeps UI clean) + // remove stale keys that no longer exist (rare, but keeps UI clean) for (const existingKey of Object.keys(next)) { - if (isLossKey(existingKey) && !wantedLossKeys.includes(existingKey)) { + if (matchesFilter(existingKey, metricFilter) && !wantedKeys.includes(existingKey)) { delete next[existingKey]; delete lastStepByKeyRef.current[existingKey]; } @@ -120,7 +143,7 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb } finally { inFlightRef.current = false; } - }, [jobID, reloadInterval]); + }, [jobID, reloadInterval, metricFilter]); useEffect(() => { // reset when job changes @@ -141,5 +164,5 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb } }, [jobID, reloadInterval, refreshLoss]); - return { series, keys, lossKeys, status, refreshLoss, setSeries }; + return { series, keys, filteredKeys, status, refreshLoss, setSeries }; } diff --git a/ui/src/types.ts b/ui/src/types.ts index 07f1e1d18..8e1d692aa 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -142,10 +142,15 @@ export interface TrainConfig { diff_output_preservation_class: string; blank_prompt_preservation?: boolean; blank_prompt_preservation_multiplier?: number; + blank_prompt_probability?: number; switch_boundary_every: number; loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped'; do_differential_guidance?: boolean; differential_guidance_scale?: number; + fixed_cycle_timesteps?: number[]; + fixed_cycle_seed?: number | null; + fixed_cycle_weight_peak_timesteps?: number[] | null; + fixed_cycle_weight_sigma?: number; } export interface QuantizeKwargsConfig { From 2b2ecc871694ae9276badc34e2e659d664a70974 Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 17:59:21 +0300 Subject: [PATCH 04/16] feat(optimizers): Adafactor min/max lr, RMS tracking, fixes - min_lr/max_lr, get_avg_learning_rate, RMS tracking methods - Fix lr=0 with relative_step, uninitialized 
state, clamp to max_lr - Truncated normal sampling, weight update RMS logging --- toolkit/optimizers/adafactor.py | 150 ++++++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 38 deletions(-) diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 8897bdc07..41596c08a 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -32,6 +32,8 @@ class Adafactor(torch.optim.Optimizer): Coefficient used to compute running averages of square beta1 (`float`, *optional*): Coefficient used for computing running averages of gradient + (first moment, like in Adam). If not None, enables momentum. + Suggested values: 0.9 (default), 0.95 or 0.99 for smoother updates. weight_decay (`float`, *optional*, defaults to 0.0): Weight decay (L2 penalty) scale_parameter (`bool`, *optional*, defaults to `True`): @@ -40,6 +42,12 @@ class Adafactor(torch.optim.Optimizer): If True, time-dependent learning rate is computed instead of external learning rate warmup_init (`bool`, *optional*, defaults to `False`): Time-dependent learning rate computation depends on whether warm-up initialization is being used + min_lr (`float`, *optional*, defaults to `1e-6`): + Minimum learning rate multiplier for warmup phase when `warmup_init=True` and `relative_step=True`. + Controls the linear growth rate: `lr = min_lr * step` during warmup. + max_lr (`float`, *optional*, defaults to `1e-2`): + Maximum learning rate cap for relative step mode when `relative_step=True`. + Acts as upper bound for `min_step` when `warmup_init=False` or when warmup phase completes. This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. 
@@ -106,13 +114,15 @@ def __init__( scale_parameter=True, relative_step=True, warmup_init=False, - do_paramiter_swapping=False, - paramiter_swapping_factor=0.1, + min_lr=1e-6, + max_lr=1e-2, + do_parameter_swapping=False, + parameter_swapping_factor=0.1, stochastic_accumulation=True, stochastic_rounding=True, ): self.stochastic_rounding = stochastic_rounding - if lr is not None and relative_step: + if lr is not None and lr != 0 and relative_step: raise ValueError( "Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -129,11 +139,13 @@ def __init__( "scale_parameter": scale_parameter, "relative_step": relative_step, "warmup_init": warmup_init, + "min_lr": min_lr, + "max_lr": max_lr, } super().__init__(params, defaults) self.base_lrs: List[float] = [ - lr for group in self.param_groups + group['lr'] for group in self.param_groups ] self.is_stochastic_rounding_accumulation = False @@ -148,60 +160,65 @@ def __init__( stochastic_grad_accummulation ) - self.do_paramiter_swapping = do_paramiter_swapping - self.paramiter_swapping_factor = paramiter_swapping_factor - self._total_paramiter_size = 0 - # count total paramiters + self.do_parameter_swapping = do_parameter_swapping + self.parameter_swapping_factor = parameter_swapping_factor + self._total_parameter_size = 0 + # count total parameters for group in self.param_groups: for param in group['params']: - self._total_paramiter_size += torch.numel(param) - # pretty print total paramiters with comma seperation - print(f"Total training paramiters: {self._total_paramiter_size:,}") + self._total_parameter_size += torch.numel(param) + # pretty print total parameters with comma separation + print(f"Total training parameters: {self._total_parameter_size:,}") - # needs to be enabled to count paramiters - if self.do_paramiter_swapping: - self.enable_paramiter_swapping(self.paramiter_swapping_factor) + # needs to be enabled to count parameters + if self.do_parameter_swapping: + 
self.enable_parameter_swapping(self.parameter_swapping_factor) - def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): - self.do_paramiter_swapping = True - self.paramiter_swapping_factor = paramiter_swapping_factor + def enable_parameter_swapping(self, parameter_swapping_factor=0.1): + self.do_parameter_swapping = True + self.parameter_swapping_factor = parameter_swapping_factor # call it an initial time - self.swap_paramiters() + self.swap_parameters() - def swap_paramiters(self): + def swap_parameters(self): all_params = [] - # deactivate all paramiters + # deactivate all parameters for group in self.param_groups: for param in group['params']: param.requires_grad_(False) # remove any grad param.grad = None all_params.append(param) - # shuffle all paramiters + # shuffle all parameters random.shuffle(all_params) - # keep activating paramiters until we are going to go over the target paramiters - target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor) - total_paramiters = 0 + # keep activating parameters until we are going to go over the target parameters + target_parameters = max(1, int(self._total_parameter_size * self.parameter_swapping_factor)) + total_parameters = 0 for param in all_params: - total_paramiters += torch.numel(param) - if total_paramiters >= target_paramiters: + param.requires_grad_(True) + total_parameters += torch.numel(param) + if total_parameters >= target_parameters: break - else: - param.requires_grad_(True) @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] if param_group["relative_step"]: - min_step = 1e-6 * \ - param_state["step"] if param_group["warmup_init"] else 1e-2 + if param_group["warmup_init"]: + min_step = param_group["min_lr"] * param_state["step"] + else: + min_step = param_group["max_lr"] rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) param_scale = 1.0 if param_group["scale_parameter"]: param_scale = max(param_group["eps"][1], 
param_state["RMS"]) - return param_scale * rel_step_sz + lr = param_scale * rel_step_sz + # Ensure learning rate doesn't exceed max_lr when using relative_step + if param_group["relative_step"]: + lr = min(lr, param_group["max_lr"]) + return lr @staticmethod def _get_options(param_group, param_shape): @@ -234,11 +251,20 @@ def step_hook(self): # adafactor manages its own lr def get_learning_rates(self): - lrs = [ - self._get_lr(group, self.state[group["params"][0]]) - for group in self.param_groups - if group["params"][0].grad is not None - ] + lrs = [] + for group in self.param_groups: + # Find first param with initialized state + lr = None + for param in group["params"]: + if param in self.state and len(self.state[param]) > 0: + lr = self._get_lr(group, self.state[param]) + break + if lr is not None: + lrs.append(lr) + elif group["lr"] is not None: + # Fallback to group lr if state not initialized + lrs.append(group["lr"]) + if len(lrs) == 0: lrs = self.base_lrs # if called before stepping return lrs @@ -288,6 +314,7 @@ def step(self, closure=None): if factored: state["exp_avg_sq_row"] = torch.zeros( grad_shape[:-1]).to(grad) + # For 2D tensors, grad_shape[:-2] is empty tuple, which is correct for column stats state["exp_avg_sq_col"] = torch.zeros( grad_shape[:-2] + grad_shape[-1:]).to(grad) else: @@ -306,8 +333,9 @@ def step(self, closure=None): state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) p_data_fp32 = p + is_quantized = isinstance(p_data_fp32, QBytesTensor) - if isinstance(p_data_fp32, QBytesTensor): + if is_quantized: p_data_fp32 = p_data_fp32.dequantize() if p.dtype != torch.float32: p_data_fp32 = p_data_fp32.clone().float() @@ -356,8 +384,54 @@ def step(self, closure=None): p_data_fp32.add_(-update) - if p.dtype != torch.float32 and self.stochastic_rounding: + # Store update RMS for monitoring + state["update_rms"] = self._rms(update).item() + + if (p.dtype != torch.float32 or is_quantized) and self.stochastic_rounding: # apply stochastic rounding 
copy_stochastic(p, p_data_fp32) return loss + + def get_avg_learning_rate(self): + lrs = self.get_learning_rates() + if len(lrs) == 0: + return 0.0 + return sum(lrs) / len(lrs) + + def get_update_rms(self): + """ + Get RMS (root mean square) of weight updates for each parameter group. + + Returns: + List[float]: RMS of weight updates for each parameter group. + Returns 0.0 for groups that haven't been updated yet. + """ + update_rms_list = [] + for group in self.param_groups: + group_rms_sum = 0.0 + group_count = 0 + for p in group["params"]: + if p in self.state and "update_rms" in self.state[p]: + group_rms_sum += self.state[p]["update_rms"] + group_count += 1 + if group_count > 0: + update_rms_list.append(group_rms_sum / group_count) + else: + update_rms_list.append(0.0) + return update_rms_list + + def get_avg_update_rms(self): + """ + Get average RMS of weight updates across all parameter groups. + + This metric represents the average magnitude of weight changes per optimization step. + Useful for monitoring training stability and convergence. + + Returns: + float: Average RMS of weight updates across all parameter groups. 
+ """ + update_rms_list = self.get_update_rms() + if len(update_rms_list) == 0: + return 0.0 + return sum(update_rms_list) / len(update_rms_list) From 373fdb2bb30dd9d729eea39ff7e62d25bf922a3d Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 17:59:35 +0300 Subject: [PATCH 05/16] refactor(device): safe_module_to_device and conditional model transfer - Add safe_module_to_device in toolkit/util/device.py - Use in ltx2, qwen_image, wan22 for low_vram-aware device handling --- .../diffusion_models/ltx2/ltx2.py | 11 ++++---- .../diffusion_models/qwen_image/qwen_image.py | 7 +++-- .../qwen_image/qwen_image_edit.py | 4 ++- .../qwen_image/qwen_image_edit_plus.py | 4 ++- .../diffusion_models/wan22/wan22_14b_model.py | 24 +++++++++------- toolkit/util/device.py | 28 +++++++++++++++++++ 6 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 toolkit/util/device.py diff --git a/extensions_built_in/diffusion_models/ltx2/ltx2.py b/extensions_built_in/diffusion_models/ltx2/ltx2.py index a6d1e6dc7..36d32787f 100644 --- a/extensions_built_in/diffusion_models/ltx2/ltx2.py +++ b/extensions_built_in/diffusion_models/ltx2/ltx2.py @@ -18,6 +18,7 @@ from toolkit.accelerator import unwrap_model from optimum.quanto import freeze from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device from toolkit.memory_management import MemoryManager from safetensors.torch import load_file from PIL import Image @@ -570,8 +571,8 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - if self.model.device == torch.device("cpu"): - self.model.to(self.device_torch) + if self.low_vram and self.model.device == torch.device("cpu"): + safe_module_to_device(self.model, self.device_torch) # handle control image if gen_config.ctrl_img is not None: @@ -771,9 +772,9 @@ def get_noise_prediction( **kwargs, ): with torch.no_grad(): - if self.model.device == torch.device("cpu"): - self.model.to(self.device_torch) 
- + if self.low_vram and self.model.device == torch.device("cpu"): + safe_module_to_device(self.model, self.device_torch) + # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2 text_embeddings = self.pad_embeds(text_embeddings) diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index 6e6813ffb..20772a29d 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -15,6 +15,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from toolkit.memory_management import MemoryManager from safetensors.torch import load_file @@ -262,7 +263,8 @@ def generate_single_image( control_img = control_img.resize( (gen_config.width, gen_config.height), Image.BILINEAR ) - self.model.to(self.device_torch) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch) # flush for low vram if we are doing that flush_between_steps = self.model_config.low_vram @@ -305,7 +307,8 @@ def get_noise_prediction( text_embeddings: PromptEmbeds, **kwargs, ): - self.model.to(self.device_torch) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch) batch_size, num_channels_latents, height, width = latent_model_input.shape ps = self.transformer.config.patch_size diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py index bcc8d735c..328cb81f9 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py @@ -16,6 +16,7 @@ from toolkit.accelerator 
import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from diffusers import ( @@ -89,7 +90,8 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch, self.torch_dtype) sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py index 8272ee464..dfe772575 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -16,6 +16,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from diffusers import ( @@ -97,7 +98,8 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch, self.torch_dtype) sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index a32183cec..9205d7d8a 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ 
b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -16,6 +16,7 @@ CustomFlowMatchEulerDiscreteScheduler, ) from toolkit.util.quantize import quantize_model +from toolkit.util.device import safe_module_to_device from .wan22_pipeline import Wan22Pipeline from diffusers import WanTransformer3DModel @@ -132,20 +133,23 @@ def forward( # todo swap the loras as well if t_name != self._active_transformer_name: if self.low_vram: - getattr(self, self._active_transformer_name).to("cpu") - getattr(self, t_name).to(self.device_torch) + safe_module_to_device( + getattr(self, self._active_transformer_name), torch.device("cpu") + ) + safe_module_to_device(getattr(self, t_name), self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name - if self.transformer.device != hidden_states.device: - if self.low_vram: - # move other transformer to cpu - other_tname = ( - "transformer_1" if t_name == "transformer_2" else "transformer_2" - ) - getattr(self, other_tname).to("cpu") + if self.low_vram and self.transformer.device != hidden_states.device: + # move other transformer to cpu + other_tname = ( + "transformer_1" if t_name == "transformer_2" else "transformer_2" + ) + safe_module_to_device( + getattr(self, other_tname), torch.device("cpu") + ) - self.transformer.to(hidden_states.device) + safe_module_to_device(self.transformer, hidden_states.device) return self.transformer( hidden_states=hidden_states, diff --git a/toolkit/util/device.py b/toolkit/util/device.py new file mode 100644 index 000000000..0a06817c0 --- /dev/null +++ b/toolkit/util/device.py @@ -0,0 +1,28 @@ +""" +Device utilities. safe_module_to_device moves a module in-place without using +Module.to(), avoiding PyTorch's swap_tensors path which fails on quantized +parameters (e.g. QLinear) that have requires_grad=False. 
+""" +from typing import Optional + +import torch + + +def safe_module_to_device( + module: torch.nn.Module, + device: torch.device, + dtype: Optional[torch.dtype] = None, +) -> None: + """ + Move module to device (and optionally dtype) by moving param/buffer .data in-place. + Avoids Module.to() which uses swap_tensors and can raise on tensors + that do not require gradients (e.g. QLinear in quantized models). + """ + for _, param in module.named_parameters(recurse=False): + if param.device != device or (dtype is not None and param.dtype != dtype): + param.data = param.data.to(device=device, dtype=dtype or param.dtype) + for _, buf in module.named_buffers(recurse=False): + if buf.device != device or (dtype is not None and buf.dtype != dtype): + buf.data = buf.data.to(device=device, dtype=dtype or buf.dtype) + for _, child in module.named_children(): + safe_module_to_device(child, device, dtype) From 87057c5912b452341da2b272a0fec50a82a49dab Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 17:59:45 +0300 Subject: [PATCH 06/16] feat(z_image): sampling model, quantization, stop job API - Separate sampling model with device handling, quantization support - Sampling model unload, LoRA on sampling transformer - Accuracy recovery adapter in quantize, file cleanup with retries - Stop job API marks job stopped and handles dead process - Sample image viewer keyboard navigation fix --- .../diffusion_models/z_image/z_image.py | 167 +++++++++++++----- toolkit/util/quantize.py | 4 +- ui/src/app/api/jobs/[jobID]/stop/route.ts | 70 +++++++- ui/src/components/SampleImageViewer.tsx | 4 +- 4 files changed, 190 insertions(+), 55 deletions(-) diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index fda894a5a..44bdf2966 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -15,6 +15,7 @@ from toolkit.accelerator import 
unwrap_model from optimum.quanto import freeze from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device from toolkit.memory_management import MemoryManager from safetensors.torch import load_file @@ -202,6 +203,32 @@ def load_model(self): flush() + # Load sampling transformer if a separate sampling_name_or_path is specified + self._sampling_transformer = None + if self.model_config.sampling_name_or_path is not None and self.model_config.sampling_name_or_path != model_path: + self.print_and_status_update("Loading sampling transformer") + sampling_model_path = self.model_config.sampling_name_or_path + sampling_transformer_path = sampling_model_path + sampling_transformer_subfolder = "transformer" + + if os.path.exists(sampling_model_path): + sampling_transformer_subfolder = None + sampling_transformer_path = os.path.join(sampling_model_path, "transformer") + + sampling_transformer = ZImageTransformer2DModel.from_pretrained( + sampling_transformer_path, subfolder=sampling_transformer_subfolder, torch_dtype=dtype + ) + + if self.model_config.quantize: + self.print_and_status_update("Quantizing sampling transformer") + quantize_model(self, sampling_transformer) + flush() + + # Always keep sampling transformer on CPU + sampling_transformer.to("cpu") + self._sampling_transformer = sampling_transformer + flush() + self.print_and_status_update("Text Encoder") tokenizer = AutoTokenizer.from_pretrained( base_model_path, subfolder="tokenizer", torch_dtype=dtype @@ -279,15 +306,21 @@ def load_model(self): def get_generation_pipeline(self): scheduler = ZImageModel.get_train_scheduler() + # Determine which transformer to use for sampling + transformer_for_sampling = self.transformer + if self._sampling_transformer is not None: + # Use sampling transformer, keep it on CPU (for testing / avoid OOM) + transformer_for_sampling = self._sampling_transformer + pipeline: ZImagePipeline = ZImagePipeline( 
scheduler=scheduler, text_encoder=unwrap_model(self.text_encoder[0]), tokenizer=self.tokenizer[0], vae=unwrap_model(self.vae), - transformer=unwrap_model(self.transformer), + transformer=unwrap_model(transformer_for_sampling), ) - pipeline = pipeline.to(self.device_torch) + # pipeline = pipeline.to(self.device_torch) return pipeline @@ -300,48 +333,58 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) - self.model.to(self.device_torch) - - sc = self.get_bucket_divisibility() - gen_config.width = int(gen_config.width // sc * sc) - gen_config.height = int(gen_config.height // sc * sc) - - # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be - # List[torch.FloatTensor] where each element is [seq_len, dim]. - # The pipeline concatenates these lists for CFG (not element-wise add). - cond_embeds = conditional_embeds.text_embeds - uncond_embeds = unconditional_embeds.text_embeds - - # Convert rank-3 tensors back to list of rank-2 tensors - def to_embed_list(embeds): - if embeds is None: - return [] - if isinstance(embeds, list): + + if self._sampling_transformer is not None: + self.model.to("cpu") + self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) + else: + self.model.to(self.device_torch, dtype=self.torch_dtype) + + try: + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be + # List[torch.FloatTensor] where each element is [seq_len, dim]. + # The pipeline concatenates these lists for CFG (not element-wise add). 
+ cond_embeds = conditional_embeds.text_embeds + uncond_embeds = unconditional_embeds.text_embeds + + # Convert rank-3 tensors back to list of rank-2 tensors + def to_embed_list(embeds): + if embeds is None: + return [] + if isinstance(embeds, list): + return embeds + if len(embeds.shape) == 3: + # [batch, seq_len, dim] -> list of [seq_len, dim] + return list(embeds.unbind(dim=0)) + elif len(embeds.shape) == 2: + # Already [seq_len, dim], wrap in list + return [embeds] return embeds - if len(embeds.shape) == 3: - # [batch, seq_len, dim] -> list of [seq_len, dim] - return list(embeds.unbind(dim=0)) - elif len(embeds.shape) == 2: - # Already [seq_len, dim], wrap in list - return [embeds] - return embeds - - cond_embeds_list = to_embed_list(cond_embeds) - uncond_embeds_list = to_embed_list(uncond_embeds) - - img = pipeline( - prompt_embeds=cond_embeds_list, - negative_prompt_embeds=uncond_embeds_list, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - generator=generator, - **extra, - ).images[0] - return img + + cond_embeds_list = to_embed_list(cond_embeds) + uncond_embeds_list = to_embed_list(uncond_embeds) + + img = pipeline( + prompt_embeds=cond_embeds_list, + negative_prompt_embeds=uncond_embeds_list, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra, + ).images[0] + return img + finally: + # Unload sampling transformer from GPU and restore main model to device + if self._sampling_transformer is not None: + self._sampling_transformer.to("cpu") + self.model.to(self.device_torch, dtype=self.torch_dtype) def get_noise_prediction( self, @@ -350,7 +393,8 @@ def get_noise_prediction( text_embeddings: PromptEmbeds, **kwargs, ): - self.model.to(self.device_torch) + if 
self.low_vram and next(self.model.parameters()).device != self.device_torch: + safe_module_to_device(self.model, self.device_torch) latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) @@ -432,6 +476,41 @@ def save_model(self, output_path, meta, save_dtype): with open(meta_path, "w") as f: yaml.dump(meta, f) + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + ): + """ + Override generate_images to handle LoRA on sampling transformer. + When using _sampling_transformer with LoRA, temporarily swap self.network + to _sampling_network (which has LoRA applied to _sampling_transformer) + for the duration of generation. + """ + saved_network = None + try: + # If using sampling transformer with LoRA, sync weights and swap network + if ( + hasattr(self, '_sampling_transformer') + and self._sampling_transformer is not None + and hasattr(self, '_sampling_network') + and self._sampling_network is not None + and hasattr(self, 'network') + and self.network is not None + ): + # Sync weights from training network to sampling network + self._sampling_network.load_state_dict(self.network.state_dict()) + # Save current network and swap to sampling network + saved_network = self.network + self.network = self._sampling_network + + # Call parent's generate_images with possibly swapped network + return super().generate_images(image_configs, sampler) + finally: + # Restore original network + if saved_network is not None: + self.network = saved_network + def get_loss_target(self, *args, **kwargs): noise = kwargs.get("noise") batch = kwargs.get("batch") diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index 27d6733a5..af188be89 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -129,6 +129,8 @@ def quantize( def quantize_model( base_model: "BaseModel", model_to_quantize: torch.nn.Module, + use_accuracy_recovery: bool = True, + adapter_attr_name: str 
= "accuracy_recovery_adapter", ): from toolkit.dequantize import patch_dequantization_on_save @@ -140,7 +142,7 @@ def quantize_model( # patch the state dict method patch_dequantization_on_save(model_to_quantize) - if base_model.model_config.accuracy_recovery_adapter is not None: + if use_accuracy_recovery and base_model.model_config.accuracy_recovery_adapter is not None: from toolkit.config_modules import NetworkConfig from toolkit.lora_special import LoRASpecialNetwork diff --git a/ui/src/app/api/jobs/[jobID]/stop/route.ts b/ui/src/app/api/jobs/[jobID]/stop/route.ts index 73b352dfc..9d1ce421e 100644 --- a/ui/src/app/api/jobs/[jobID]/stop/route.ts +++ b/ui/src/app/api/jobs/[jobID]/stop/route.ts @@ -1,8 +1,20 @@ import { NextRequest, NextResponse } from 'next/server'; import { PrismaClient } from '@prisma/client'; +import { getTrainingFolder } from '@/server/settings'; +import path from 'path'; +import fs from 'fs'; const prisma = new PrismaClient(); +function isProcessAlive(pid: number): boolean { + try { + process.kill(pid, 0); + return true; + } catch { + return false; + } +} + export async function GET(request: NextRequest, { params }: { params: { jobID: string } }) { const { jobID } = await params; @@ -10,14 +22,54 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s where: { id: jobID }, }); - // update job status to 'running' - await prisma.job.update({ - where: { id: jobID }, - data: { - stop: true, - info: 'Stopping job...', - }, - }); + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const markStopped = (info: string) => + prisma.job.update({ + where: { id: jobID }, + data: { + stop: true, + status: 'stopped', + info, + }, + }); + + if (job.status !== 'running') { + const updated = await markStopped(job.status === 'stopped' ? 
job.info : 'Job stopped'); + return NextResponse.json(updated); + } + + let pid: number | null = null; + try { + const trainingRoot = await getTrainingFolder(); + const pidPath = path.join(trainingRoot, job.name, 'pid.txt'); + if (fs.existsSync(pidPath)) { + const raw = fs.readFileSync(pidPath, 'utf8').trim(); + const n = parseInt(raw, 10); + if (Number.isInteger(n) && n > 0) pid = n; + } + } catch { + // pid file missing or unreadable — treat as no process + } + + if (pid === null) { + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); + } + + if (!isProcessAlive(pid)) { + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); + } + + try { + process.kill(pid, 'SIGTERM'); + } catch { + // Process may have exited; still mark as stopped + } - return NextResponse.json(job); + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); } diff --git a/ui/src/components/SampleImageViewer.tsx b/ui/src/components/SampleImageViewer.tsx index e8804a9a4..799427835 100644 --- a/ui/src/components/SampleImageViewer.tsx +++ b/ui/src/components/SampleImageViewer.tsx @@ -82,7 +82,7 @@ export default function SampleImageViewer({ if (idx < 0 || idx >= sampleImages.length) return; onChange(sampleImages[idx]); }, - [sampleImages, numSamples, onChange], + [sampleImages, onChange], ); const currentIndex = useMemo(() => { @@ -167,9 +167,11 @@ export default function SampleImageViewer({ onCancel(); break; case 'ArrowUp': + event.preventDefault(); handleArrowUp(); break; case 'ArrowDown': + event.preventDefault(); handleArrowDown(); break; case 'ArrowLeft': From cf43441f7108a60d3812163db48279a5bf80e5f4 Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 20:13:39 +0300 Subject: [PATCH 07/16] fix(optimizers): adjust max_lr and improve learning rate clamping - Change max_lr from 1e-2 to 1e-4 for Adafactor optimizer - Update comment to clarify learning rate clamping between min_lr and max_lr 
when using relative_step fix(optimizers): streamline learning rate clamping in Adafactor - Simplified the learning rate clamping logic to ensure it consistently respects min_lr and max_lr boundaries. - Removed the conditional check for relative_step, enhancing code clarity. --- toolkit/optimizers/adafactor.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 41596c08a..1879f3333 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -115,7 +115,7 @@ def __init__( relative_step=True, warmup_init=False, min_lr=1e-6, - max_lr=1e-2, + max_lr=1e-4, do_parameter_swapping=False, parameter_swapping_factor=0.1, stochastic_accumulation=True, @@ -215,9 +215,8 @@ def _get_lr(param_group, param_state): if param_group["scale_parameter"]: param_scale = max(param_group["eps"][1], param_state["RMS"]) lr = param_scale * rel_step_sz - # Ensure learning rate doesn't exceed max_lr when using relative_step - if param_group["relative_step"]: - lr = min(lr, param_group["max_lr"]) + # Ensure learning rate is between min_lr and max_lr + lr = max(param_group["min_lr"], min(lr, param_group["max_lr"])) return lr @staticmethod From 43b1655ee3de53b468da363bc9acb86f24d5774d Mon Sep 17 00:00:00 2001 From: Daan Date: Fri, 13 Feb 2026 20:39:31 +0300 Subject: [PATCH 08/16] refactor(z_image): streamline embedding conversion and device management - Simplified the embedding conversion process in ZImageModel for better clarity and efficiency. - Improved device management for the sampling transformer, ensuring proper loading and unloading to optimize resource usage. - Enhanced code readability by removing redundant checks and restructuring the flow. fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls Revert "fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls" This reverts commit 7acdd142cd030089a0eb6734baa3194590cc410f. 
refactor(zimage): load sampling transformer first and extract load helpers

- Add _load_sampling_transformer() and _load_transformer(model_path)
- Call sampling transformer load before main transformer to control
  peak VRAM
- Shrink load_model() by delegating to helpers; preserve base_model_path
  resolution
- Restore and add comments and docstrings in load paths

feat(zimage): add debug_zimage_load to trace safetensors load/mmap

- Add model.debug_zimage_load in ModelConfig to enable load debugging
- When enabled, patch safetensors.torch.load_file and
  safetensors.safe_open to log path, size, and duration via
  print_and_status_update
- Log markers before loading sampling vs main transformer to correlate
  slow sampling load with specific file opens
- Patches applied once per process to avoid double-wrap

fix(zimage): normalize model paths to avoid HF identifier warning

Add normalize_model_path() in toolkit/paths.py to strip trailing path
separators. Use it in Z-Image for name_or_path, sampling_name_or_path,
and extras_name_or_path so paths like "e:\...\snapshots\hash\" no longer
trigger "The module name (originally ) is not a valid Python identifier"
from Hugging Face transformers.

fix(zimage): use dtype instead of deprecated torch_dtype for transformers

Pass dtype=dtype to AutoTokenizer.from_pretrained and
Qwen3ForCausalLM.from_pretrained to remove the "torch_dtype is deprecated!
Use dtype instead!" warnings when loading the text encoder.
--- .../diffusion_models/z_image/z_image.py | 237 +++++++++++------- toolkit/config_modules.py | 3 + toolkit/models/base_model.py | 13 + toolkit/paths.py | 10 + 4 files changed, 175 insertions(+), 88 deletions(-) diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 44bdf2966..9cb4e2943 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -1,5 +1,6 @@ import os -from typing import List, Optional +import time +from typing import List, Optional, Tuple import huggingface_hub import torch @@ -17,7 +18,9 @@ from toolkit.util.quantize import quantize, get_qtype, quantize_model from toolkit.util.device import safe_module_to_device from toolkit.memory_management import MemoryManager +from toolkit.paths import normalize_model_path from safetensors.torch import load_file +import safetensors from transformers import AutoTokenizer, Qwen3ForCausalLM from diffusers import AutoencoderKL @@ -37,6 +40,9 @@ "shift": 3.0, } +# Set when safetensors patches are applied for debug_zimage_load (avoid double-wrap) +_zimage_load_debug_patched = False + class ZImageModel(BaseModel): arch = "zimage" @@ -148,11 +154,45 @@ def load_training_adapter(self, transformer: ZImageTransformer2DModel): # tell the model to invert assistant on inference since we want remove lora effects self.invert_assistant_lora = True - def load_model(self): + def _load_sampling_transformer(self) -> Optional[ZImageTransformer2DModel]: + """Load a separate transformer for inference/sampling when sampling_name_or_path is set. + Returns None if not configured or when sampling path equals the training model path. 
+ """ + if ( + self.model_config.sampling_name_or_path is None + or self.model_config.sampling_name_or_path == self.model_config.name_or_path + ): + return None dtype = self.torch_dtype - self.print_and_status_update("Loading ZImage model") - model_path = self.model_config.name_or_path - base_model_path = self.model_config.extras_name_or_path + self.print_and_status_update("Loading sampling transformer") + sampling_model_path = normalize_model_path(self.model_config.sampling_name_or_path) + sampling_transformer_path = sampling_model_path + sampling_transformer_subfolder = "transformer" + + if os.path.exists(sampling_model_path): + sampling_transformer_subfolder = None + sampling_transformer_path = os.path.join(sampling_model_path, "transformer") + + sampling_transformer = ZImageTransformer2DModel.from_pretrained( + sampling_transformer_path, + subfolder=sampling_transformer_subfolder, + torch_dtype=dtype, + ) + if self.model_config.quantize: + self.print_and_status_update("Quantizing sampling transformer") + quantize_model(self, sampling_transformer) + flush() + # Always keep sampling transformer on CPU to save VRAM during training + sampling_transformer.to("cpu") + flush() + return sampling_transformer + + def _load_transformer(self, model_path: str) -> Tuple[ZImageTransformer2DModel, str]: + """Load the training transformer and resolve base_model_path for VAE/text_encoder. + Returns (transformer, base_model_path). base_model_path is used for tokenizer, text_encoder, VAE. + """ + dtype = self.torch_dtype + base_model_path = normalize_model_path(self.model_config.extras_name_or_path) self.print_and_status_update("Loading transformer") @@ -161,20 +201,19 @@ def load_model(self): if os.path.exists(transformer_path): transformer_subfolder = None transformer_path = os.path.join(transformer_path, "transformer") - # check if the path is a full checkpoint. + # Check if the path is a full checkpoint (contains text_encoder, vae, etc.) 
te_folder_path = os.path.join(model_path, "text_encoder") - # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): base_model_path = model_path transformer = ZImageTransformer2DModel.from_pretrained( transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype ) - - # load assistant lora if specified + # Load assistant LoRA if specified (e.g. for ARA-style training) if self.model_config.assistant_lora_path is not None: self.load_training_adapter(transformer) - # set qtype to be float8 if it is qfloat8 + # Set qtype to float8 if it is qfloat8 (quantize compatibility) if self.model_config.qtype == "qfloat8": self.model_config.qtype = "float8" @@ -202,39 +241,74 @@ def load_model(self): transformer.to("cpu") flush() + return transformer, base_model_path - # Load sampling transformer if a separate sampling_name_or_path is specified - self._sampling_transformer = None - if self.model_config.sampling_name_or_path is not None and self.model_config.sampling_name_or_path != model_path: - self.print_and_status_update("Loading sampling transformer") - sampling_model_path = self.model_config.sampling_name_or_path - sampling_transformer_path = sampling_model_path - sampling_transformer_subfolder = "transformer" - - if os.path.exists(sampling_model_path): - sampling_transformer_subfolder = None - sampling_transformer_path = os.path.join(sampling_model_path, "transformer") - - sampling_transformer = ZImageTransformer2DModel.from_pretrained( - sampling_transformer_path, subfolder=sampling_transformer_subfolder, torch_dtype=dtype - ) - - if self.model_config.quantize: - self.print_and_status_update("Quantizing sampling transformer") - quantize_model(self, sampling_transformer) - flush() - - # Always keep sampling transformer on CPU - sampling_transformer.to("cpu") - self._sampling_transformer = sampling_transformer - flush() + def load_model(self): + global _zimage_load_debug_patched + dtype = self.torch_dtype + 
self.print_and_status_update("Loading ZImage model") + model_path = normalize_model_path(self.model_config.name_or_path) + + if self.model_config.debug_zimage_load and not _zimage_load_debug_patched: + log_func = self.print_and_status_update + orig_load_file = safetensors.torch.load_file + orig_safe_open = safetensors.safe_open + + def _debug_path_and_size(path): + if path is not None and isinstance(path, (str, os.PathLike)) and os.path.isfile(path): + path_str = os.path.abspath(path) + try: + size = os.path.getsize(path) + return path_str, f" size={size}" + except OSError: + return path_str, "" + path_str = repr(path) if path else "" + return path_str, "" + + def _wrapped_load_file(*args, **kwargs): + path = args[0] if args else kwargs.get("filename") or kwargs.get("path") + path_str, size_str = _debug_path_and_size(path) + start = time.perf_counter() + result = orig_load_file(*args, **kwargs) + duration = time.perf_counter() - start + log_func(f"[ZImage debug] load_file path={path_str}{size_str} duration={duration:.3f}s") + return result + + def _wrapped_safe_open(*args, **kwargs): + path = args[0] if args else kwargs.get("filename") or kwargs.get("path") + path_str, size_str = _debug_path_and_size(path) + start = time.perf_counter() + result = orig_safe_open(*args, **kwargs) + + class _WrappedCtx: + def __enter__(_self): + return result.__enter__() + + def __exit__(_self, *exc): + duration = time.perf_counter() - start + log_func(f"[ZImage debug] safe_open path={path_str}{size_str} duration={duration:.3f}s") + return result.__exit__(*exc) + + return _WrappedCtx() + + safetensors.torch.load_file = _wrapped_load_file + safetensors.safe_open = _wrapped_safe_open + _zimage_load_debug_patched = True + + if self.model_config.debug_zimage_load: + self.print_and_status_update("[ZImage debug] === Loading sampling transformer (from_pretrained) ===") + # Load sampling transformer first (if configured) to control peak VRAM + self._sampling_transformer = 
self._load_sampling_transformer() + if self.model_config.debug_zimage_load: + self.print_and_status_update("[ZImage debug] === Loading main transformer (from_pretrained) ===") + transformer, base_model_path = self._load_transformer(model_path) self.print_and_status_update("Text Encoder") tokenizer = AutoTokenizer.from_pretrained( - base_model_path, subfolder="tokenizer", torch_dtype=dtype + base_model_path, subfolder="tokenizer", dtype=dtype ) text_encoder = Qwen3ForCausalLM.from_pretrained( - base_model_path, subfolder="text_encoder", torch_dtype=dtype + base_model_path, subfolder="text_encoder", dtype=dtype ) if ( @@ -333,58 +407,45 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - - if self._sampling_transformer is not None: - self.model.to("cpu") - self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) - else: - self.model.to(self.device_torch, dtype=self.torch_dtype) - - try: - sc = self.get_bucket_divisibility() - gen_config.width = int(gen_config.width // sc * sc) - gen_config.height = int(gen_config.height // sc * sc) - - # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be - # List[torch.FloatTensor] where each element is [seq_len, dim]. - # The pipeline concatenates these lists for CFG (not element-wise add). 
- cond_embeds = conditional_embeds.text_embeds - uncond_embeds = unconditional_embeds.text_embeds - - # Convert rank-3 tensors back to list of rank-2 tensors - def to_embed_list(embeds): - if embeds is None: - return [] - if isinstance(embeds, list): - return embeds - if len(embeds.shape) == 3: - # [batch, seq_len, dim] -> list of [seq_len, dim] - return list(embeds.unbind(dim=0)) - elif len(embeds.shape) == 2: - # Already [seq_len, dim], wrap in list - return [embeds] + sc = self.get_bucket_divisibility() + gen_config.width = int(gen_config.width // sc * sc) + gen_config.height = int(gen_config.height // sc * sc) + + # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be + # List[torch.FloatTensor] where each element is [seq_len, dim]. + # The pipeline concatenates these lists for CFG (not element-wise add). + cond_embeds = conditional_embeds.text_embeds + uncond_embeds = unconditional_embeds.text_embeds + + # Convert rank-3 tensors back to list of rank-2 tensors + def to_embed_list(embeds): + if embeds is None: + return [] + if isinstance(embeds, list): return embeds - - cond_embeds_list = to_embed_list(cond_embeds) - uncond_embeds_list = to_embed_list(uncond_embeds) - - img = pipeline( - prompt_embeds=cond_embeds_list, - negative_prompt_embeds=uncond_embeds_list, - height=gen_config.height, - width=gen_config.width, - num_inference_steps=gen_config.num_inference_steps, - guidance_scale=gen_config.guidance_scale, - latents=gen_config.latents, - generator=generator, - **extra, - ).images[0] - return img - finally: - # Unload sampling transformer from GPU and restore main model to device - if self._sampling_transformer is not None: - self._sampling_transformer.to("cpu") - self.model.to(self.device_torch, dtype=self.torch_dtype) + if len(embeds.shape) == 3: + # [batch, seq_len, dim] -> list of [seq_len, dim] + return list(embeds.unbind(dim=0)) + elif len(embeds.shape) == 2: + # Already [seq_len, dim], wrap in list + return [embeds] + return embeds + 
+ cond_embeds_list = to_embed_list(cond_embeds) + uncond_embeds_list = to_embed_list(uncond_embeds) + + img = pipeline( + prompt_embeds=cond_embeds_list, + negative_prompt_embeds=uncond_embeds_list, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra, + ).images[0] + return img def get_noise_prediction( self, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ae3447156..5e9b5a23c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -685,6 +685,9 @@ def __init__(self, **kwargs): # for models that support it (e.g., zimage), a separate model path for sampling/inference # training uses name_or_path, sampling uses sampling_name_or_path if set self.sampling_name_or_path: Optional[str] = kwargs.get("sampling_name_or_path", None) + + # enable debug logging for safetensors load (path, size, duration) to diagnose mmap/load differences + self.debug_zimage_load: bool = kwargs.get("debug_zimage_load", False) # path to an accuracy recovery adapter, either local or remote self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index ae117c857..fb9bebc2c 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -428,10 +428,18 @@ def generate_images( if network is not None: assert network.is_active + # Load sampling transformer onto device if it exists + if self._sampling_transformer is not None: + self.model.to("cpu") + self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) + else: + self.model.to(self.device_torch, dtype=self.torch_dtype) + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] extra = {} + validation_image = None if self.adapter is not None and 
gen_config.adapter_image_path is not None: validation_image = Image.open(gen_config.adapter_image_path) @@ -668,6 +676,11 @@ def generate_images( if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): self.adapter.clear_memory() + # Unload sampling transformer from GPU and restore main model to device + if self._sampling_transformer is not None: + self._sampling_transformer.to("cpu") + self.model.to(self.device_torch, dtype=self.torch_dtype) + # clear pipeline and cache to reduce vram usage del pipeline torch.cuda.empty_cache() diff --git a/toolkit/paths.py b/toolkit/paths.py index edd36ce19..ad4d29ffb 100644 --- a/toolkit/paths.py +++ b/toolkit/paths.py @@ -1,4 +1,5 @@ import os +from typing import Optional TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config') @@ -22,3 +23,12 @@ def get_path(path): if not os.path.isabs(path): path = os.path.join(TOOLKIT_ROOT, path) return path + + +def normalize_model_path(path: Optional[str]) -> Optional[str]: + """Strip trailing path separators so Hugging Face transformers don't get empty basenames. + Avoids the 'The module name (originally ) is not a valid Python identifier' warning. + """ + if isinstance(path, str): + return path.rstrip(os.sep).rstrip("/") + return path From 4f954ed68e78feef4030c6762aa87f77ae464c2d Mon Sep 17 00:00:00 2001 From: Daan Date: Tue, 17 Feb 2026 00:48:19 +0300 Subject: [PATCH 09/16] fix(train): free memory before checkpoint save to reduce OOM risk Call optimizer.zero_grad(set_to_none=True), torch.cuda.synchronize(), and flush() before self.save() in BaseSDTrainProcess so gradients and CUDA cache are released before building state_dict and hashes. fix(train): validate optimizer state load and document save behavior - Add param count check before loading optimizer.pt; skip load and warn when current and saved param counts differ to avoid wrong state mapping. 
- Comment that optimizer is always saved via unwrap to match pre-prepare load target; drop unused exception variable in inner except. --- jobs/process/BaseSDTrainProcess.py | 31 +++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0b3e8e31e..b391af6d6 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -684,14 +684,14 @@ def save(self, step=None): print_acc(f"Saved checkpoint to {file_path}") - # save optimizer + # save optimizer (always unwrap so state matches the optimizer we load into before accelerator.prepare()) if self.optimizer is not None: try: filename = f'optimizer.pt' file_path = os.path.join(self.save_root, filename) try: state_dict = unwrap_model(self.optimizer).state_dict() - except Exception as e: + except Exception: state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) print_acc(f"Saved optimizer to {file_path}") @@ -2221,7 +2221,27 @@ def run(self): try: print_acc(f"Loading optimizer state from {optimizer_state_file_path}") optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) - optimizer.load_state_dict(optimizer_state_dict) + + # PyTorch maps optimizer state by param order; param count must match + # or state will be applied to wrong parameters + current_param_count = sum( + len(g["params"]) for g in optimizer.param_groups + ) + saved_param_count = sum( + len(g["params"]) for g in optimizer_state_dict.get("param_groups", []) + ) + + if current_param_count != saved_param_count: + print_acc( + f"WARNING: Optimizer state NOT loaded: param count mismatch. " + f"Current optimizer has {current_param_count} params, " + f"but saved state has {saved_param_count} params. " + "Training will continue with fresh optimizer state." 
+ ) + else: + optimizer.load_state_dict(optimizer_state_dict) + print_acc("Optimizer state restored successfully.") + del optimizer_state_dict flush() except Exception as e: @@ -2508,6 +2528,11 @@ def run(self): if self.progress_bar is not None: self.progress_bar.pause() print_acc(f"\nSaving at step {self.step_num}") + # free memory before save to reduce OOM risk + optimizer.zero_grad(set_to_none=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + flush() self.save(self.step_num) self.ensure_params_requires_grad() # clear any grads From 18a5501498dfa20b4f05b6716c190b5f8050ef5e Mon Sep 17 00:00:00 2001 From: Daan Date: Sun, 15 Feb 2026 18:50:33 +0300 Subject: [PATCH 10/16] fix(peft): apply alpha key fix for all peft types when loading LoRA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move $$alpha → .alpha key normalization into common peft block so it runs for all peft (lora, lokr, etc.). Remove duplicate alpha fix from lokr-only block to avoid missing_keys for alpha in state_dict. 
--- toolkit/network_mixins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 493878bd3..1ce7a16bf 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -641,13 +641,13 @@ def load_weights(self: Network, file, force_weight_mapping=False): load_key = load_key.replace('.', '$$') load_key = load_key.replace('$$lora_down$$', '.lora_down.') load_key = load_key.replace('$$lora_up$$', '.lora_up.') - + if load_key.endswith('$$alpha'): + load_key = load_key[:-7] + '.alpha' + # patch lokr, not sure why we need to but whatever if self.network_type.lower() == "lokr": load_key = load_key.replace('$$lokr_w1', '.lokr_w1') load_key = load_key.replace('$$lokr_w2', '.lokr_w2') - if load_key.endswith('$$alpha'): - load_key = load_key[:-7] + '.alpha' if self.network_type.lower() == "lokr": # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1 From 881bb4d8693749c602e928a7eaeff0eb5009e1ea Mon Sep 17 00:00:00 2001 From: Daan Date: Sun, 15 Feb 2026 00:07:56 +0300 Subject: [PATCH 11/16] refactor(dataloader): remove redundant token shuffling logic - Eliminated unnecessary shuffling of tokens as it is already handled in the subsequent code. fix(dataloader): skip shuffle_tokens when caching text embeddings Apply shuffle_tokens only when not cache_text_embeddings, so cached embeddings are saved from the original caption order. Aligns with existing caption_dropout and token_dropout behavior and keeps cache paths stable. 
feat(shuffle): keep first segment/token fixed when shuffling captions and embeddings

- get_caption: when shuffle_tokens is on, keep first comma-separated
  segment in place and shuffle only the rest
- shuffle_sequence: keep index 0 fixed and permute only positions
  1..seq_len-1

feat(train): enhance text embedding caching logic and memory management

- Introduced a mechanism to clear cached text embeddings based on the
  dataset configuration, improving memory efficiency during training.
- Updated the AiToolkitDataset to retain prompt embeddings if they are
  cached, ensuring consistency across epochs.
- Added logging for shuffling cached text embedding tokens at the start
  of each epoch, enhancing visibility into the training process.

feat(train): log shuffling of cached text embedding tokens per epoch

Added a print statement to log when cached text embedding tokens are
shuffled at the start of each epoch, enhancing visibility into the
training process.

feat(train): shuffle cached text embedding tokens every epoch

When cache_text_embeddings is true, shuffle token order along the
sequence dimension at the start of each new epoch (after first full
pass). Add PromptEmbeds.shuffle_sequence(), dataset
set_epoch_num/clear_cached_embeddings_memory, wire epoch boundary in
BaseSDTrainProcess, and sync epoch_num to datasets at train start for
correct behavior on resume from checkpoint.

fix(sd_trainer): use trigger_word when cached text embeds are missing

- Prefer trigger over blank when batch.prompt_embeds is None (no cache
  on disk)
- Require trigger_word when cache_text_embeddings is enabled and batch
  has no cached embeds; raise ValueError otherwise
- Keep reg batches using blank (no trigger)

fix(dataloader): respect shuffle_tokens for cached text embeddings

Skip shuffling cached prompt embeds when shuffle_tokens is false in
set_epoch_num and load_prompt_embedding.
--- extensions_built_in/sd_trainer/SDTrainer.py | 13 ++++++-- jobs/process/BaseSDTrainProcess.py | 12 +++++++- toolkit/data_loader.py | 3 ++ toolkit/dataloader_mixins.py | 30 +++++++++++++++---- toolkit/prompt_utils.py | 33 +++++++++++++++++++++ 5 files changed, 82 insertions(+), 9 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index da6c82677..2b475c517 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1526,13 +1526,20 @@ def get_adapter_multiplier(): self.device_torch, dtype=dtype ) else: - embeds_to_use = self.cached_blank_embeds.clone().detach().to( - self.device_torch, dtype=dtype - ) + # No cache on disk: use trigger_word instead of blank embedding if self.cached_trigger_embeds is not None and not is_reg: embeds_to_use = self.cached_trigger_embeds.clone().detach().to( self.device_torch, dtype=dtype ) + else: + if self.is_caching_text_embeddings: + raise ValueError( + "cache_text_embeddings is enabled but no cached embeds in batch. " + "Set trigger_word when using cache_text_embeddings so fallback is available when cache is missing." 
+ ) + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) conditional_embeds = concat_prompt_embeds( [embeds_to_use] * noisy_latents.shape[0] ) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b391af6d6..07f418791 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -27,7 +27,7 @@ from toolkit.basic import value_map from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.custom_adapter import CustomAdapter -from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.data_loader import get_dataloader_from_datasets, get_dataloader_datasets, trigger_dataloader_setup_epoch from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.ema import ExponentialMovingAverage from toolkit.embedding import Embedding @@ -2322,6 +2322,11 @@ def run(self): dataloader_reg = None dataloader_iterator_reg = None + # Sync epoch_num to datasets so resume from checkpoint keeps correct shuffle-for-cached-embeddings behavior + if dataloader is not None: + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + # zero any gradients optimizer.zero_grad() @@ -2407,6 +2412,11 @@ def run(self): dataloader_iterator = iter(dataloader) trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + clear_embeddings_cache = not getattr(dataset.dataset_config, 'cache_text_embeddings', False) + if clear_embeddings_cache: + dataset.clear_cached_embeddings_memory() if self.train_config.gradient_accumulation_steps == -1: # if we are accumulating for an entire epoch, trigger a step self.is_grad_accumulation_step = False diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 51605c982..cfb084aa0 100644 --- a/toolkit/data_loader.py +++ 
b/toolkit/data_loader.py @@ -598,7 +598,10 @@ def __len__(self): def _get_single_item(self, index) -> 'FileItemDTO': file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index]) + file_item._current_epoch_num = getattr(self, '_epoch_num', 0) file_item.load_and_process_image(self.transform) + if file_item.is_text_embedding_cached and file_item.prompt_embeds is not None: + self.file_list[index].prompt_embeds = file_item.prompt_embeds file_item.load_caption(self.caption_dict) return file_item diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 39dbc6468..151f60d2a 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -415,8 +415,9 @@ def get_caption( new_token_list.append(token) token_list = new_token_list - if self.dataset_config.shuffle_tokens: - random.shuffle(token_list) + # Redundant code. Shuffling is performed in the code below. + # if self.dataset_config.shuffle_tokens: + # random.shuffle(token_list) # join back together caption = ', '.join(token_list) @@ -436,14 +437,17 @@ def get_caption( # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))] # caption = caption + ', ' + trigger - if self.dataset_config.shuffle_tokens: - # shuffle again + if self.dataset_config.shuffle_tokens and not self.dataset_config.cache_text_embeddings: + # shuffle, keep first segment (until first comma) in place token_list = caption.split(',') # trim whitespace token_list = [x.strip() for x in token_list] # remove empty strings token_list = [x for x in token_list if x] - random.shuffle(token_list) + if len(token_list) > 1: + rest = token_list[1:] + random.shuffle(rest) + token_list = [token_list[0]] + rest caption = ', '.join(token_list) if caption == '': pass @@ -1965,6 +1969,9 @@ def load_prompt_embedding(self, device=None): if self.prompt_embeds is None: # load it from disk self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) + if getattr(self, 
'_current_epoch_num', 0) > 0 and getattr(self.dataset_config, 'shuffle_tokens', False): + self.prompt_embeds.shuffle_sequence() + print_acc(f"Cached text embedding tokens shuffled (epoch {getattr(self, '_current_epoch_num', 0)})") class TextEmbeddingCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): @@ -1972,6 +1979,19 @@ def __init__(self: 'AiToolkitDataset', **kwargs): if hasattr(super(), '__init__'): super().__init__(**kwargs) self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + self._epoch_num: int = 0 + + def set_epoch_num(self: 'AiToolkitDataset', epoch_num: int) -> None: + self._epoch_num = epoch_num + if epoch_num > 0 and self.is_caching_text_embeddings and self.dataset_config.shuffle_tokens: + for file_item in self.file_list: + if getattr(file_item, 'prompt_embeds', None) is not None: + file_item.prompt_embeds.shuffle_sequence() + print_acc(f"\nCached text embedding tokens shuffled (epoch {epoch_num})\n") + + def clear_cached_embeddings_memory(self: 'AiToolkitDataset') -> None: + for file_item in self.file_list: + file_item.cleanup_text_embedding() def cache_text_embeddings(self: 'AiToolkitDataset'): with accelerator.main_process_first(): diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 0f3c10a9b..9813a7056 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -114,6 +114,39 @@ def expand_to_batch(self, batch_size): pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe + def shuffle_sequence(self) -> None: + """Shuffle token order along sequence dimension (dim=1), keeping first token fixed. In-place. 
No-op if seq_len <= 1.""" + def make_perm(seq_len: int, device: torch.device) -> torch.Tensor: + if seq_len <= 1: + return torch.arange(seq_len, device=device, dtype=torch.long) + return torch.cat([ + torch.zeros(1, device=device, dtype=torch.long), + 1 + torch.randperm(seq_len - 1, device=device, dtype=torch.long), + ]) + if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple): + te_list = list(self.text_embeds) + attn_list = list(self.attention_mask) if self.attention_mask is not None and isinstance(self.attention_mask, (list, tuple)) else None + attn_is_tuple = isinstance(self.attention_mask, tuple) if self.attention_mask is not None else False + for i, t in enumerate(te_list): + if t.dim() >= 2 and t.shape[1] > 1: + perm = make_perm(t.shape[1], t.device) + te_list[i] = t[:, perm, ...].contiguous() + if attn_list is not None and i < len(attn_list): + attn = attn_list[i] + if attn.dim() >= 2 and attn.shape[1] > 1: + attn_list[i] = attn[:, perm, ...].contiguous() + elif self.attention_mask is not None and not isinstance(self.attention_mask, (list, tuple)) and i == 0: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + self.text_embeds = tuple(te_list) if isinstance(self.text_embeds, tuple) else te_list + if attn_list is not None: + self.attention_mask = tuple(attn_list) if attn_is_tuple else attn_list + else: + if self.text_embeds.dim() >= 2 and self.text_embeds.shape[1] > 1: + perm = make_perm(self.text_embeds.shape[1], self.text_embeds.device) + self.text_embeds = self.text_embeds[:, perm, ...].contiguous() + if self.attention_mask is not None and self.attention_mask.dim() >= 2 and self.attention_mask.shape[1] > 1: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + def save(self, path: str): """ Save the prompt embeds to a file. 
From 2d54c629b4032aee07935c5815f938bc0294a333 Mon Sep 17 00:00:00 2001 From: Daan Date: Sun, 15 Feb 2026 00:07:56 +0300 Subject: [PATCH 12/16] refactor(dataloader): remove redundant token shuffling logic - Eliminated unnecessary shuffling of tokens as it is already handled in the subsequent code. fix(dataloader): skip shuffle_tokens when caching text embeddings Apply shuffle_tokens only when not cache_text_embeddings, so cached embeddings are saved from the original caption order. Aligns with existing caption_dropout and token_dropout behavior and keeps cache paths stable. feat(shuffle): keep first segment/token fixed when shuffling captions and embeddings - get_caption: when shuffle_tokens is on, keep first comma-separated segment in place and shuffle only the rest - shuffle_sequence: keep index 0 fixed and permute only positions 1..seq_len-1 feat(train): enhance text embedding caching logic and memory management - Introduced a mechanism to clear cached text embeddings based on the dataset configuration, improving memory efficiency during training. - Updated the AiToolkitDataset to retain prompt embeddings if they are cached, ensuring consistency across epochs. - Added logging for shuffling cached text embedding tokens at the start of each epoch, enhancing visibility into the training process. feat(train): log shuffling of cached text embedding tokens per epoch Added a print statement to log when cached text embedding tokens are shuffled at the start of each epoch, enhancing visibility into the training process. feat(train): shuffle cached text embedding tokens every epoch When cache_text_embeddings is true, shuffle token order along the sequence dimension at the start of each new epoch (after first full pass). Add PromptEmbeds.shuffle_sequence(), dataset set_epoch_num/ clear_cached_embeddings_memory, wire epoch boundary in BaseSDTrainProcess, and sync epoch_num to datasets at train start for correct behavior on resume from checkpoint. 
fix(sd_trainer): use trigger_word when cached text embeds are missing - Prefer trigger over blank when batch.prompt_embeds is None (no cache on disk) - Require trigger_word when cache_text_embeddings is enabled and batch has no cached embeds; raise ValueError otherwise - Keep reg batches using blank (no trigger) fix(dataloader): respect shuffle_tokens for cached text embeddings Skip shuffling cached prompt embeds when shuffle_tokens is false in set_epoch_num and load_prompt_embedding. --- extensions_built_in/sd_trainer/SDTrainer.py | 13 ++++++-- jobs/process/BaseSDTrainProcess.py | 12 +++++++- toolkit/data_loader.py | 3 ++ toolkit/dataloader_mixins.py | 30 +++++++++++++++---- toolkit/prompt_utils.py | 33 +++++++++++++++++++++ 5 files changed, 82 insertions(+), 9 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index da6c82677..2b475c517 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1526,13 +1526,20 @@ def get_adapter_multiplier(): self.device_torch, dtype=dtype ) else: - embeds_to_use = self.cached_blank_embeds.clone().detach().to( - self.device_torch, dtype=dtype - ) + # No cache on disk: use trigger_word instead of blank embedding if self.cached_trigger_embeds is not None and not is_reg: embeds_to_use = self.cached_trigger_embeds.clone().detach().to( self.device_torch, dtype=dtype ) + else: + if self.is_caching_text_embeddings: + raise ValueError( + "cache_text_embeddings is enabled but no cached embeds in batch. " + "Set trigger_word when using cache_text_embeddings so fallback is available when cache is missing." 
+ ) + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) conditional_embeds = concat_prompt_embeds( [embeds_to_use] * noisy_latents.shape[0] ) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b391af6d6..07f418791 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -27,7 +27,7 @@ from toolkit.basic import value_map from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.custom_adapter import CustomAdapter -from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.data_loader import get_dataloader_from_datasets, get_dataloader_datasets, trigger_dataloader_setup_epoch from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.ema import ExponentialMovingAverage from toolkit.embedding import Embedding @@ -2322,6 +2322,11 @@ def run(self): dataloader_reg = None dataloader_iterator_reg = None + # Sync epoch_num to datasets so resume from checkpoint keeps correct shuffle-for-cached-embeddings behavior + if dataloader is not None: + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + # zero any gradients optimizer.zero_grad() @@ -2407,6 +2412,11 @@ def run(self): dataloader_iterator = iter(dataloader) trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + clear_embeddings_cache = not getattr(dataset.dataset_config, 'cache_text_embeddings', False) + if clear_embeddings_cache: + dataset.clear_cached_embeddings_memory() if self.train_config.gradient_accumulation_steps == -1: # if we are accumulating for an entire epoch, trigger a step self.is_grad_accumulation_step = False diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 51605c982..cfb084aa0 100644 --- a/toolkit/data_loader.py +++ 
b/toolkit/data_loader.py @@ -598,7 +598,10 @@ def __len__(self): def _get_single_item(self, index) -> 'FileItemDTO': file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index]) + file_item._current_epoch_num = getattr(self, '_epoch_num', 0) file_item.load_and_process_image(self.transform) + if file_item.is_text_embedding_cached and file_item.prompt_embeds is not None: + self.file_list[index].prompt_embeds = file_item.prompt_embeds file_item.load_caption(self.caption_dict) return file_item diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 39dbc6468..151f60d2a 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -415,8 +415,9 @@ def get_caption( new_token_list.append(token) token_list = new_token_list - if self.dataset_config.shuffle_tokens: - random.shuffle(token_list) + # Redundant code. Shuffling is performed in the code below. + # if self.dataset_config.shuffle_tokens: + # random.shuffle(token_list) # join back together caption = ', '.join(token_list) @@ -436,14 +437,17 @@ def get_caption( # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))] # caption = caption + ', ' + trigger - if self.dataset_config.shuffle_tokens: - # shuffle again + if self.dataset_config.shuffle_tokens and not self.dataset_config.cache_text_embeddings: + # shuffle, keep first segment (until first comma) in place token_list = caption.split(',') # trim whitespace token_list = [x.strip() for x in token_list] # remove empty strings token_list = [x for x in token_list if x] - random.shuffle(token_list) + if len(token_list) > 1: + rest = token_list[1:] + random.shuffle(rest) + token_list = [token_list[0]] + rest caption = ', '.join(token_list) if caption == '': pass @@ -1965,6 +1969,9 @@ def load_prompt_embedding(self, device=None): if self.prompt_embeds is None: # load it from disk self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) + if getattr(self, 
'_current_epoch_num', 0) > 0 and getattr(self.dataset_config, 'shuffle_tokens', False): + self.prompt_embeds.shuffle_sequence() + print_acc(f"Cached text embedding tokens shuffled (epoch {getattr(self, '_current_epoch_num', 0)})") class TextEmbeddingCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): @@ -1972,6 +1979,19 @@ def __init__(self: 'AiToolkitDataset', **kwargs): if hasattr(super(), '__init__'): super().__init__(**kwargs) self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + self._epoch_num: int = 0 + + def set_epoch_num(self: 'AiToolkitDataset', epoch_num: int) -> None: + self._epoch_num = epoch_num + if epoch_num > 0 and self.is_caching_text_embeddings and self.dataset_config.shuffle_tokens: + for file_item in self.file_list: + if getattr(file_item, 'prompt_embeds', None) is not None: + file_item.prompt_embeds.shuffle_sequence() + print_acc(f"\nCached text embedding tokens shuffled (epoch {epoch_num})\n") + + def clear_cached_embeddings_memory(self: 'AiToolkitDataset') -> None: + for file_item in self.file_list: + file_item.cleanup_text_embedding() def cache_text_embeddings(self: 'AiToolkitDataset'): with accelerator.main_process_first(): diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 0f3c10a9b..9813a7056 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -114,6 +114,39 @@ def expand_to_batch(self, batch_size): pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe + def shuffle_sequence(self) -> None: + """Shuffle token order along sequence dimension (dim=1), keeping first token fixed. In-place. 
No-op if seq_len <= 1.""" + def make_perm(seq_len: int, device: torch.device) -> torch.Tensor: + if seq_len <= 1: + return torch.arange(seq_len, device=device, dtype=torch.long) + return torch.cat([ + torch.zeros(1, device=device, dtype=torch.long), + 1 + torch.randperm(seq_len - 1, device=device, dtype=torch.long), + ]) + if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple): + te_list = list(self.text_embeds) + attn_list = list(self.attention_mask) if self.attention_mask is not None and isinstance(self.attention_mask, (list, tuple)) else None + attn_is_tuple = isinstance(self.attention_mask, tuple) if self.attention_mask is not None else False + for i, t in enumerate(te_list): + if t.dim() >= 2 and t.shape[1] > 1: + perm = make_perm(t.shape[1], t.device) + te_list[i] = t[:, perm, ...].contiguous() + if attn_list is not None and i < len(attn_list): + attn = attn_list[i] + if attn.dim() >= 2 and attn.shape[1] > 1: + attn_list[i] = attn[:, perm, ...].contiguous() + elif self.attention_mask is not None and not isinstance(self.attention_mask, (list, tuple)) and i == 0: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + self.text_embeds = tuple(te_list) if isinstance(self.text_embeds, tuple) else te_list + if attn_list is not None: + self.attention_mask = tuple(attn_list) if attn_is_tuple else attn_list + else: + if self.text_embeds.dim() >= 2 and self.text_embeds.shape[1] > 1: + perm = make_perm(self.text_embeds.shape[1], self.text_embeds.device) + self.text_embeds = self.text_embeds[:, perm, ...].contiguous() + if self.attention_mask is not None and self.attention_mask.dim() >= 2 and self.attention_mask.shape[1] > 1: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + def save(self, path: str): """ Save the prompt embeds to a file. 
From 8b6330e68e86c893a0d803fe84ae017f3c79879a Mon Sep 17 00:00:00 2001 From: Daan Date: Tue, 17 Feb 2026 21:47:42 +0300 Subject: [PATCH 13/16] fix(adafactor): apply new min_lr/max_lr on restart after loading checkpoint --- toolkit/optimizers/adafactor.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 1879f3333..16f16508e 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -144,6 +144,10 @@ def __init__( } super().__init__(params, defaults) + # Store LR limits so they can be reapplied after load_state_dict (restart with new config). + self._min_lr = min_lr + self._max_lr = max_lr + self.base_lrs: List[float] = [ group['lr'] for group in self.param_groups ] @@ -173,8 +177,14 @@ def __init__( # needs to be enabled to count parameters if self.do_parameter_swapping: self.enable_parameter_swapping(self.parameter_swapping_factor) - - + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + # Apply current run's min_lr/max_lr so changed config is used after restart. 
+ for group in self.param_groups: + group["min_lr"] = self._min_lr + group["max_lr"] = self._max_lr + def enable_parameter_swapping(self, parameter_swapping_factor=0.1): self.do_parameter_swapping = True self.parameter_swapping_factor = parameter_swapping_factor From 0b66ca57b588bce6cebf762fdfc5450c5b475b78 Mon Sep 17 00:00:00 2001 From: Daan Date: Tue, 17 Feb 2026 21:47:54 +0300 Subject: [PATCH 14/16] refactor(paths): rename normalize_model_path to normalize_path and strip whitespace --- toolkit/paths.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/toolkit/paths.py b/toolkit/paths.py index ad4d29ffb..ad8a31394 100644 --- a/toolkit/paths.py +++ b/toolkit/paths.py @@ -25,10 +25,11 @@ def get_path(path): return path -def normalize_model_path(path: Optional[str]) -> Optional[str]: - """Strip trailing path separators so Hugging Face transformers don't get empty basenames. - Avoids the 'The module name (originally ) is not a valid Python identifier' warning. +def normalize_path(path: Optional[str]) -> Optional[str]: + """Strip leading/trailing whitespace and trailing path separators. + Use for any path: model dirs, LoRA/safetensors files, etc. """ - if isinstance(path, str): - return path.rstrip(os.sep).rstrip("/") - return path + if not isinstance(path, str): + return path + path = path.strip() + return path.rstrip(os.sep).rstrip("/") From 108617e6900ba0c73f43ef4f696803b99ad15e92 Mon Sep 17 00:00:00 2001 From: Daan Date: Tue, 17 Feb 2026 21:48:04 +0300 Subject: [PATCH 15/16] feat(debug): add memory_debug util and use for CUDA load/unload logging - Add toolkit/util/debug.py with set_debug_config() and memory_debug() context manager - Register logging config in BaseSDTrainProcess - Wrap text encoder unload and Z-Image load stages in memory_debug for optional [DEBUG ...] 
CUDA lines - Add debug flag and CUDA memory log on text encoder unload --- extensions_built_in/sd_trainer/SDTrainer.py | 10 ++- toolkit/config_modules.py | 4 +- toolkit/util/debug.py | 82 +++++++++++++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 toolkit/util/debug.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 2b475c517..650a4abbe 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -33,6 +33,7 @@ import math from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe +from toolkit.util.debug import memory_debug from toolkit.util.losses import wavelet_loss, stepped_loss import torch.nn.functional as F from toolkit.unloader import unload_text_encoder @@ -343,12 +344,15 @@ def hook_before_train_loop(self): # unload the text encoder if self.is_caching_text_embeddings: - unload_text_encoder(self.sd) + with memory_debug(print_acc, "UNLOAD TEXT ENCODER"): + unload_text_encoder(self.sd) + flush() else: # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. # keep legacy usage for now. 
- self.sd.text_encoder_to("cpu") - flush() + with memory_debug(print_acc, "UNLOAD TEXT ENCODER"): + self.sd.text_encoder_to("cpu") + flush() if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: # make sure we have this if not unloading diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 5e9b5a23c..7962c0788 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -6,6 +6,7 @@ import torch import torchaudio +from toolkit.paths import normalize_path from toolkit.prompt_utils import PromptEmbeds ImgExt = Literal['jpg', 'png', 'webp'] @@ -36,6 +37,7 @@ def __init__(self, **kwargs): self.verbose: bool = kwargs.get('verbose', False) self.use_wandb: bool = kwargs.get('use_wandb', False) self.use_ui_logger: bool = kwargs.get('use_ui_logger', False) + self.debug: bool = kwargs.get('debug', False) self.project_name: str = kwargs.get('project_name', 'ai-toolkit') self.run_name: str = kwargs.get('run_name', None) @@ -219,7 +221,7 @@ def __init__(self, **kwargs): self.layer_offloading = kwargs.get('layer_offloading', False) # start from a pretrained lora - self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None) + self.pretrained_lora_path = normalize_path(kwargs.get('pretrained_lora_path', None)) AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] diff --git a/toolkit/util/debug.py b/toolkit/util/debug.py new file mode 100644 index 000000000..61d72a2f2 --- /dev/null +++ b/toolkit/util/debug.py @@ -0,0 +1,82 @@ +""" +Debug utilities. memory_debug context manager measures GPU/RAM around a block; +enabled via set_debug_config(config) with config.debug. Extensible to RAM later. 
+""" +import contextlib +from typing import Callable + +import torch + +_debug_config = None + + +def set_debug_config(config) -> None: + """Register the config object used to decide if memory debug is enabled (config.debug).""" + global _debug_config + _debug_config = config + + +def is_debug_enabled() -> bool: + """Return True if debug logging is enabled (config.debug). Used for optional debug messages.""" + if _debug_config is None: + return False + return bool(getattr(_debug_config, "debug", False)) + + +def _is_enabled_for_cuda() -> bool: + if _debug_config is None: + return False + if not getattr(_debug_config, "debug", False): + return False + return torch.cuda.is_available() + + +def _cuda_snapshot_mb(): + """Return (allocated_mb, max_allocated_mb).""" + return ( + torch.cuda.memory_allocated() / 2**20, + torch.cuda.max_memory_allocated() / 2**20, + ) + + +def _format_cuda_diff(label: str, before: tuple, after: tuple) -> list: + mem_before, max_before = before + mem_after, max_after = after + delta = mem_before - mem_after + delta_str = f"(freed {delta:.1f} MB)" if delta >= 0 else f"(+{-delta:.1f} MB)" + return [ + f"[DEBUG {label}] CUDA allocated: {mem_before:.1f} MB -> {mem_after:.1f} MB {delta_str}", + f"[DEBUG {label}] CUDA max: {max_before:.1f} MB -> {max_after:.1f} MB", + ] + + +@contextlib.contextmanager +def memory_debug( + print_fn: Callable[[str], None], + label: str, + kind: str = "cuda", +): + """ + Context manager: measure memory around the block and log if debug is enabled. + enabled is read from the config set via set_debug_config(); no need to pass it. + kind="cuda" measures CUDA allocated/max; other kinds (e.g. "ram") are stubs for now. 
+ """ + if kind != "cuda": + yield + return + if not _is_enabled_for_cuda(): + yield + return + before = _cuda_snapshot_mb() + try: + yield + finally: + torch.cuda.synchronize() + after = _cuda_snapshot_mb() + for line in _format_cuda_diff(label, before, after): + print_fn(line) + + +def cuda_memory_debug(print_fn: Callable[[str], None], label: str): + """Alias for memory_debug(print_fn, label, kind="cuda").""" + return memory_debug(print_fn, label, kind="cuda") From 536431e87814cdb1c112b844fa60da0856859d11 Mon Sep 17 00:00:00 2001 From: Daan Date: Tue, 17 Feb 2026 21:48:56 +0300 Subject: [PATCH 16/16] fix(sampling), feat(lora): sampling VRAM, single LoRA, path and debug wiring - fix(sampling): load sampling transformer on CPU, guarantee unload to CPU after generate_images, do not force unet to GPU - feat(lora): use single LoRA for training and sampling via shared parameters; free pretrained LoRA memory after load - Use normalize_path in Z-Image name_or_path; wire memory_debug in BaseSDTrainProcess LoRA load and Z-Image stages - train.example.yaml, base_model, network_mixins updates --- .../diffusion_models/z_image/z_image.py | 141 ++--- .../sd_trainer/config/train.example.yaml | 1 + jobs/process/BaseSDTrainProcess.py | 130 ++--- toolkit/lora_special.py | 461 ++++++++-------- toolkit/models/base_model.py | 503 +++++++++--------- toolkit/network_mixins.py | 7 +- 6 files changed, 648 insertions(+), 595 deletions(-) diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 9cb4e2943..7fc257757 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -15,10 +15,11 @@ ) from toolkit.accelerator import unwrap_model from optimum.quanto import freeze +from toolkit.util.debug import memory_debug from toolkit.util.quantize import quantize, get_qtype, quantize_model from toolkit.util.device import safe_module_to_device from 
toolkit.memory_management import MemoryManager -from toolkit.paths import normalize_model_path +from toolkit.paths import normalize_path from safetensors.torch import load_file import safetensors @@ -165,7 +166,7 @@ def _load_sampling_transformer(self) -> Optional[ZImageTransformer2DModel]: return None dtype = self.torch_dtype self.print_and_status_update("Loading sampling transformer") - sampling_model_path = normalize_model_path(self.model_config.sampling_name_or_path) + sampling_model_path = normalize_path(self.model_config.sampling_name_or_path) sampling_transformer_path = sampling_model_path sampling_transformer_subfolder = "transformer" @@ -177,12 +178,13 @@ def _load_sampling_transformer(self) -> Optional[ZImageTransformer2DModel]: sampling_transformer_path, subfolder=sampling_transformer_subfolder, torch_dtype=dtype, + device_map="cpu", ) if self.model_config.quantize: self.print_and_status_update("Quantizing sampling transformer") quantize_model(self, sampling_transformer) flush() - # Always keep sampling transformer on CPU to save VRAM during training + # Already on CPU via device_map="cpu"; ensure it stays there sampling_transformer.to("cpu") flush() return sampling_transformer @@ -192,7 +194,7 @@ def _load_transformer(self, model_path: str) -> Tuple[ZImageTransformer2DModel, Returns (transformer, base_model_path). base_model_path is used for tokenizer, text_encoder, VAE. 
""" dtype = self.torch_dtype - base_model_path = normalize_model_path(self.model_config.extras_name_or_path) + base_model_path = normalize_path(self.model_config.extras_name_or_path) self.print_and_status_update("Loading transformer") @@ -247,7 +249,7 @@ def load_model(self): global _zimage_load_debug_patched dtype = self.torch_dtype self.print_and_status_update("Loading ZImage model") - model_path = normalize_model_path(self.model_config.name_or_path) + model_path = normalize_path(self.model_config.name_or_path) if self.model_config.debug_zimage_load and not _zimage_load_debug_patched: log_func = self.print_and_status_update @@ -298,84 +300,89 @@ def __exit__(_self, *exc): if self.model_config.debug_zimage_load: self.print_and_status_update("[ZImage debug] === Loading sampling transformer (from_pretrained) ===") # Load sampling transformer first (if configured) to control peak VRAM - self._sampling_transformer = self._load_sampling_transformer() + with memory_debug(self.print_and_status_update, "Loading sampling transformer"): + self._sampling_transformer = self._load_sampling_transformer() if self.model_config.debug_zimage_load: self.print_and_status_update("[ZImage debug] === Loading main transformer (from_pretrained) ===") - transformer, base_model_path = self._load_transformer(model_path) + with memory_debug(self.print_and_status_update, "Loading transformer"): + transformer, base_model_path = self._load_transformer(model_path) self.print_and_status_update("Text Encoder") - tokenizer = AutoTokenizer.from_pretrained( - base_model_path, subfolder="tokenizer", dtype=dtype - ) - text_encoder = Qwen3ForCausalLM.from_pretrained( - base_model_path, subfolder="text_encoder", dtype=dtype - ) - - if ( - self.model_config.layer_offloading - and self.model_config.layer_offloading_text_encoder_percent > 0 - ): - MemoryManager.attach( - text_encoder, - self.device_torch, - offload_percent=self.model_config.layer_offloading_text_encoder_percent, + with 
memory_debug(self.print_and_status_update, "Text Encoder"): + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", dtype=dtype ) - text_encoder.to(self.device_torch, dtype=dtype) - flush() + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) - if self.model_config.quantize_te: - self.print_and_status_update("Quantizing Text Encoder") - quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) - freeze(text_encoder) + text_encoder.to(self.device_torch, dtype=dtype) flush() + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + self.print_and_status_update("Loading VAE") - vae = AutoencoderKL.from_pretrained( - base_model_path, subfolder="vae", torch_dtype=dtype - ) + with memory_debug(self.print_and_status_update, "Loading VAE"): + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ) self.noise_scheduler = ZImageModel.get_train_scheduler() self.print_and_status_update("Making pipe") - - kwargs = {} - - pipe: ZImagePipeline = ZImagePipeline( - scheduler=self.noise_scheduler, - text_encoder=None, - tokenizer=tokenizer, - vae=vae, - transformer=None, - **kwargs, - ) - # for quantization, it works best to do these after making the pipe - pipe.text_encoder = text_encoder - pipe.transformer = transformer + with memory_debug(self.print_and_status_update, "Making pipe"): + kwargs = {} + + pipe: ZImagePipeline = ZImagePipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + 
**kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer self.print_and_status_update("Preparing Model") + with memory_debug(self.print_and_status_update, "Preparing Model"): + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] - text_encoder = [pipe.text_encoder] - tokenizer = [pipe.tokenizer] - - # leave it on cpu for now - if not self.low_vram: - pipe.transformer = pipe.transformer.to(self.device_torch) + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) - flush() - # just to make sure everything is on the right device and dtype - text_encoder[0].to(self.device_torch) - text_encoder[0].requires_grad_(False) - text_encoder[0].eval() - flush() + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() # save it to the model class - self.vae = vae - self.text_encoder = text_encoder # list of text encoders - self.tokenizer = tokenizer # list of tokenizers - self.model = pipe.transformer - self.pipeline = pipe - self.print_and_status_update("Model Loaded") + with memory_debug(self.print_and_status_update, "Model Loaded"): + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") def get_generation_pipeline(self): scheduler = ZImageModel.get_train_scheduler() @@ -550,7 +557,7 @@ def generate_images( """ saved_network = None try: - # If using sampling transformer with LoRA, sync weights and swap network + # If using sampling transformer with LoRA, swap to sampling network (weights already shared) if ( hasattr(self, '_sampling_transformer') and self._sampling_transformer is not None @@ -559,9 +566,6 @@ def 
generate_images( and hasattr(self, 'network') and self.network is not None ): - # Sync weights from training network to sampling network - self._sampling_network.load_state_dict(self.network.state_dict()) - # Save current network and swap to sampling network saved_network = self.network self.network = self._sampling_network @@ -571,6 +575,9 @@ def generate_images( # Restore original network if saved_network is not None: self.network = saved_network + # Ensure sampling transformer is off GPU (defense in depth) + if getattr(self, "_sampling_transformer", None) is not None: + self._sampling_transformer.to("cpu") def get_loss_target(self, *args, **kwargs): noise = kwargs.get("noise") diff --git a/extensions_built_in/sd_trainer/config/train.example.yaml b/extensions_built_in/sd_trainer/config/train.example.yaml index 793d5d55b..4587e786e 100644 --- a/extensions_built_in/sd_trainer/config/train.example.yaml +++ b/extensions_built_in/sd_trainer/config/train.example.yaml @@ -76,6 +76,7 @@ config: log_every: 10 # log every this many steps use_wandb: false # not supported yet verbose: false + debug: false # enable to log CUDA memory around text encoder unload # You can put any information you want here, and it will be saved in the model. # The below is an example, but you can put your grocery list in it if you want. 
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 07f418791..0033e433a 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -72,6 +72,7 @@ import hashlib from toolkit.util.blended_blur_noise import get_blended_blur_noise +from toolkit.util.debug import memory_debug, set_debug_config from toolkit.util.get_model import get_model_class def flush(): @@ -131,6 +132,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.has_first_sample_requested = False self.first_sample_config = self.sample_config self.logging_config = LoggingConfig(**self.get_conf('logging', {})) + set_debug_config(self.logging_config) self.logger = create_logger(self.logging_config, config, self.save_root) self.optimizer: torch.optim.Optimizer = None self.lr_scheduler = None @@ -873,12 +875,10 @@ def load_training_state_from_metadata(self, path): def load_weights(self, path): if self.network is not None: - extra_weights = self.network.load_weights(path) + self.network.load_weights(path) self.load_training_state_from_metadata(path) - return extra_weights else: print_acc("load_weights not implemented for non-network models") - return None def apply_snr(self, seperated_loss, timesteps): if self.train_config.learnable_snr_gos: @@ -1991,58 +1991,62 @@ def run(self): # Create sampling network if sampling_transformer is specified and LoRA is used if hasattr(self.sd, '_sampling_transformer') and self.sd._sampling_transformer is not None: - print_acc("Creating sampling network for _sampling_transformer") - sampling_network = NetworkClass( - text_encoder=text_encoder, - unet=self.sd._sampling_transformer, - lora_dim=self.network_config.linear, - multiplier=1.0, - alpha=self.network_config.linear_alpha, - train_unet=self.train_config.train_unet, - train_text_encoder=self.train_config.train_text_encoder, - conv_lora_dim=self.network_config.conv, - conv_alpha=self.network_config.conv_alpha, - 
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, - is_v2=self.model_config.is_v2, - is_v3=self.model_config.is_v3, - is_pixart=self.model_config.is_pixart, - is_auraflow=self.model_config.is_auraflow, - is_flux=self.model_config.is_flux, - is_lumina2=self.model_config.is_lumina2, - is_ssd=self.model_config.is_ssd, - is_vega=self.model_config.is_vega, - dropout=self.network_config.dropout, - rank_dropout=self.network_config.rank_dropout, - module_dropout=self.network_config.module_dropout, - use_text_encoder_1=self.model_config.use_text_encoder_1, - use_text_encoder_2=self.model_config.use_text_encoder_2, - use_bias=is_lorm, - is_lorm=is_lorm, - network_config=self.network_config, - network_type=self.network_config.type, - transformer_only=self.network_config.transformer_only, - is_transformer=self.sd.is_transformer, - base_model=self.sd, - **network_kwargs - ) - - sampling_network.force_to(self.device_torch, dtype=torch.float32) - sampling_network._update_torch_multiplier() - - sampling_network.apply_to( - text_encoder, - self.sd._sampling_transformer, - self.train_config.train_text_encoder, - self.train_config.train_unet - ) - - # Set can_merge_in same as main network (False if quantized/layer_offloading) - if self.model_config.quantize or self.model_config.layer_offloading: - sampling_network.can_merge_in = False - - # Store sampling network on model for use during generation - self.sd._sampling_network = sampling_network - flush() + with memory_debug(print_acc, "Creating sampling network"): + print_acc("Creating sampling network for _sampling_transformer") + sampling_network = NetworkClass( + text_encoder=text_encoder, + unet=self.sd._sampling_transformer, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + 
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + rank_dropout=self.network_config.rank_dropout, + module_dropout=self.network_config.module_dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, + base_model=self.sd, + **network_kwargs + ) + + sampling_network.force_to(self.device_torch, dtype=torch.float32) + sampling_network._update_torch_multiplier() + + sampling_network.apply_to( + text_encoder, + self.sd._sampling_transformer, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + # Set can_merge_in same as main network (False if quantized/layer_offloading) + if self.model_config.quantize or self.model_config.layer_offloading: + sampling_network.can_merge_in = False + + # Store sampling network on model for use during generation + self.sd._sampling_network = sampling_network + # One LoRA for both: share parameters with training network (no copy/sync) + if hasattr(sampling_network, "share_parameters_with"): + sampling_network.share_parameters_with(self.network) + flush() # LyCORIS doesnt have default_lr config = { @@ -2069,13 +2073,14 @@ def run(self): lora_name = f"{lora_name}_LoRA" latest_save_path = self.get_latest_save_path(lora_name) - extra_weights = None if latest_save_path is not None: - print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - print_acc(f"Loading from 
{latest_save_path}") - extra_weights = self.load_weights(latest_save_path) - self.network.multiplier = 1.0 - + with memory_debug(print_acc, "pretrained_lora_load"): + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"Loading from {latest_save_path}") + self.load_weights(latest_save_path) + self.network.multiplier = 1.0 + flush() + if self.network_config.layer_offloading: MemoryManager.attach( self.network, @@ -2294,7 +2299,8 @@ def run(self): print_acc("Skipping first sample due to config setting") elif self.step_num <= 1 or self.train_config.force_first_sample: print_acc("Generating baseline samples before training") - self.sample(self.step_num) + with memory_debug(print_acc, "Generating baseline samples before training"): + self.sample(self.step_num) if self.accelerator.is_local_main_process: self.progress_bar = ToolkitProgressBar( diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 12cec5c1b..2e799f7f1 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -14,6 +14,7 @@ from .config_modules import NetworkConfig from .lorm import count_parameters from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin +from toolkit.util.debug import memory_debug from toolkit.kohya_lora import LoRANetwork from toolkit.models.DoRA import DoRAModule @@ -291,237 +292,238 @@ def __init__( self.full_train_in_out = full_train_in_out - if modules_dim is not None: - print(f"create LoRA network from weights") - elif block_dims is not None: - print(f"create LoRA network from block_dims") - print( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") - if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") - else: - print(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") - print( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - if self.conv_lora_dim is not None: + with memory_debug(print, "create LoRA network"): + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") print( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print( + f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") - # create module instances - def create_modules( + # create module instances + def create_modules( is_unet: bool, text_encoder_idx: Optional[int], # None, 1, 2 root_module: torch.nn.Module, target_replace_modules: List[torch.nn.Module], - ) -> List[LoRAModule]: - unet_prefix = self.LORA_PREFIX_UNET - if self.peft_format: - unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: - unet_prefix = f"lora_transformer" + ) -> List[LoRAModule]: + unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: - unet_prefix = "transformer" - - prefix = ( - unet_prefix - if is_unet - else ( - self.LORA_PREFIX_TEXT_ENCODER - if text_encoder_idx is None - else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + unet_prefix = self.PEFT_PREFIX_UNET + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: + unet_prefix = f"lora_transformer" + if self.peft_format: + unet_prefix = "transformer" + + prefix = ( + unet_prefix + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) ) - ) - loras = [] - skipped = [] - attached_modules = [] - lora_shape_dict = {} - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ in LINEAR_MODULES - is_conv2d = child_module.__class__.__name__ in CONV_MODULES - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - - - lora_name = [prefix, name, child_name] - # filter out blank - lora_name = [x for x in lora_name if x and x != ""] - lora_name = ".".join(lora_name) - # if it doesnt have a name, it wil have two dots - lora_name.replace("..", ".") - clean_name = 
lora_name - if self.peft_format: - # we replace this on saving - lora_name = lora_name.replace(".", "$$") - else: - lora_name = lora_name.replace(".", "_") - - skip = False - if any([word in clean_name for word in self.ignore_if_contains]): - skip = True - - # see if it is over threshold - if count_parameters(child_module) < parameter_threshold: - skip = True - - if self.transformer_only and is_unet: - transformer_block_names = None - if base_model is not None: - transformer_block_names = base_model.get_transformer_block_names() - - if transformer_block_names is not None: - if not any([name in lora_name for name in transformer_block_names]): - skip = True + loras = [] + skipped = [] + attached_modules = [] + lora_shape_dict = {} + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + + lora_name = [prefix, name, child_name] + # filter out blank + lora_name = [x for x in lora_name if x and x != ""] + lora_name = ".".join(lora_name) + # if it doesnt have a name, it wil have two dots + lora_name.replace("..", ".") + clean_name = lora_name + if self.peft_format: + # we replace this on saving + lora_name = lora_name.replace(".", "$$") else: - if self.is_pixart: - if "transformer_blocks" not in lora_name: - skip = True - if self.is_flux: - if "transformer_blocks" not in lora_name: - skip = True - if self.is_lumina2: - if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name: - skip = True - if self.is_v3: - if "transformer_blocks" not in lora_name: - skip = True - - # handle custom models - if hasattr(root_module, 'transformer_blocks'): - if "transformer_blocks" not in lora_name: - skip = True - - if 
hasattr(root_module, 'blocks'): - if "blocks" not in lora_name: - skip = True - - if hasattr(root_module, 'single_blocks'): - if "single_blocks" not in lora_name and "double_blocks" not in lora_name: - skip = True - - if (is_linear or is_conv2d) and not skip: + lora_name = lora_name.replace(".", "_") - if self.only_if_contains is not None: - if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): - continue + skip = False + if any([word in clean_name for word in self.ignore_if_contains]): + skip = True - dim = None - alpha = None - - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha - - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or ( - self.conv_lora_dim is not None or conv_block_dims is not None): - skipped.append(lora_name) - continue + # see if it is over threshold + if count_parameters(child_module) < parameter_threshold: + skip = True - module_kwargs = {} - - if self.network_type.lower() == "lokr": - module_kwargs["factor"] = self.network_config.lokr_factor - - if self.is_ara: - module_kwargs["is_ara"] = True - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - network=self, - parent=module, - use_bias=use_bias, - **module_kwargs - ) - loras.append(lora) - if self.network_type.lower() == "lokr": - try: - lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)] - except: - pass - else: - if self.full_rank: - lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + if self.transformer_only and is_unet: + 
transformer_block_names = None + if base_model is not None: + transformer_block_names = base_model.get_transformer_block_names() + + if transformer_block_names is not None: + if not any([name in lora_name for name in transformer_block_names]): + skip = True else: - lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] - return loras, skipped - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - # create LoRA for text encoder - # 毎回すべてのモジュールを作るのは無駄なので要検討 - self.text_encoder_loras = [] - skipped_te = [] - if train_text_encoder: - for i, text_encoder in enumerate(text_encoders): - if not use_text_encoder_1 and i == 0: - continue - if not use_text_encoder_2 and i == 1: - continue - if len(text_encoders) > 1: - index = i + 1 - print(f"create LoRA for Text Encoder {index}:") - else: - index = None - print(f"create LoRA for Text Encoder:") - - replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - - if self.is_pixart: - replace_modules = ["T5EncoderModel"] - - text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = target_lin_modules - if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: - target_modules += target_conv_modules - - if is_v3: - target_modules = ["SD3Transformer2DModel"] - - if is_pixart: - target_modules = ["PixArtTransformer2DModel"] - - if is_auraflow: - target_modules = ["AuraFlowTransformer2DModel"] - - if is_flux: - target_modules = ["FluxTransformer2DModel"] - - if is_lumina2: - target_modules = ["Lumina2Transformer2DModel"] + if self.is_pixart: + if "transformer_blocks" not in lora_name: + skip = True + if self.is_flux: + if 
"transformer_blocks" not in lora_name: + skip = True + if self.is_lumina2: + if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name: + skip = True + if self.is_v3: + if "transformer_blocks" not in lora_name: + skip = True + + # handle custom models + if hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'blocks'): + if "blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'single_blocks'): + if "single_blocks" not in lora_name and "double_blocks" not in lora_name: + skip = True + + if (is_linear or is_conv2d) and not skip: + + if self.only_if_contains is not None: + if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or ( + self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + module_kwargs = {} + + if self.network_type.lower() == "lokr": + module_kwargs["factor"] = self.network_config.lokr_factor + + if self.is_ara: + module_kwargs["is_ara"] = True + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + network=self, + parent=module, + use_bias=use_bias, + **module_kwargs + ) + loras.append(lora) + if self.network_type.lower() == "lokr": + try: + lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), 
list(lora.lokr_w2.weight.shape)] + except: + pass + else: + if self.full_rank: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + else: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + if train_text_encoder: + for i, text_encoder in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + + if self.is_pixart: + replace_modules = ["T5EncoderModel"] + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = target_lin_modules + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += target_conv_modules + + if is_v3: + target_modules = ["SD3Transformer2DModel"] + + if is_pixart: + target_modules = ["PixArtTransformer2DModel"] + + if is_auraflow: + target_modules = ["AuraFlowTransformer2DModel"] + + if is_flux: + target_modules = ["FluxTransformer2DModel"] + + if is_lumina2: + target_modules = ["Lumina2Transformer2DModel"] - if train_unet: - self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - else: - self.unet_loras = [] - skipped_un = [] - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + if 
train_unet: + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + else: + self.unet_loras = [] + skipped_un = [] + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: @@ -580,6 +582,31 @@ def create_modules( unet.conv_in = self.unet_conv_in unet.conv_out = self.unet_conv_out + def share_parameters_with(self, other: "LoRASpecialNetwork") -> None: + """ + Share all trainable parameters with another network of the same structure. + Used so sampling_network uses the same LoRA weights as the training network + (one LoRA for both training and sampling, no copy/sync). + """ + assert len(self.unet_loras) == len(other.unet_loras), "unet_loras length mismatch" + assert len(self.text_encoder_loras) == len(other.text_encoder_loras), "text_encoder_loras length mismatch" + + def _share_lora_pair(my_lora: torch.nn.Module, other_lora: torch.nn.Module) -> None: + assert getattr(my_lora, "lora_name", None) == getattr(other_lora, "lora_name", None), ( + f"lora name mismatch: {getattr(my_lora, 'lora_name', None)} vs {getattr(other_lora, 'lora_name', None)}" + ) + for name, param in other_lora.named_parameters(): + parts = name.split(".") + obj = my_lora + for p in parts[:-1]: + obj = getattr(obj, p) + setattr(obj, parts[-1], param) + + for my_lora, other_lora in zip(self.unet_loras, other.unet_loras): + _share_lora_pair(my_lora, other_lora) + for my_lora, other_lora in zip(self.text_encoder_loras, other.text_encoder_loras): + _share_lora_pair(my_lora, other_lora) + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # call Lora prepare_optimizer_params all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index fb9bebc2c..906996f7b 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -41,6 +41,7 @@ from toolkit.accelerator 
import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc +from toolkit.util.debug import is_debug_enabled if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -410,6 +411,7 @@ def generate_images( rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + pipeline_created = (pipeline is None) if pipeline is None: pipeline = self.get_generation_pipeline() try: @@ -423,267 +425,277 @@ def generate_images( # pipeline.to(self.device_torch) - with network: - with torch.no_grad(): - if network is not None: - assert network.is_active + try: + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active - # Load sampling transformer onto device if it exists - if self._sampling_transformer is not None: - self.model.to("cpu") - self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) - else: - self.model.to(self.device_torch, dtype=self.torch_dtype) + # Load sampling transformer onto device if it exists + if self._sampling_transformer is not None: + self.model.to("cpu") + self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) + else: + self.model.to(self.device_torch, dtype=self.torch_dtype) - for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): - gen_config = image_configs[i] + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] - extra = {} + extra = {} - validation_image = None - if self.adapter is not None and gen_config.adapter_image_path is not None: - validation_image = Image.open(gen_config.adapter_image_path) - if ".inpaint." 
not in gen_config.adapter_image_path: - validation_image = validation_image.convert("RGB") + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path) + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? + validation_image = validation_image.resize( + (gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize( + (gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['control_image'] = validation_image + extra['control_image_idx'] = gen_config.ctrl_idx + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + 
self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) else: - # make sure it has an alpha - if validation_image.mode != "RGBA": - raise ValueError("Inpainting images must have an alpha channel") - if isinstance(self.adapter, T2IAdapter): - # not sure why this is double?? 
- validation_image = validation_image.resize( - (gen_config.width * 2, gen_config.height * 2)) - extra['image'] = validation_image - extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, ControlNetModel): - validation_image = validation_image.resize( - (gen_config.width, gen_config.height)) - extra['image'] = validation_image - extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: - validation_image = validation_image.resize((gen_config.width, gen_config.height)) - extra['control_image'] = validation_image - extra['control_image_idx'] = gen_config.ctrl_idx - if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - if isinstance(self.adapter, CustomAdapter): - # todo allow loading multiple - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - self.adapter.num_images = 1 - if isinstance(self.adapter, ReferenceAdapter): - # need -1 to 1 - validation_image = transforms.ToTensor()(validation_image) - validation_image = validation_image * 2.0 - 1.0 - validation_image = validation_image.unsqueeze(0) - self.adapter.set_reference_images(validation_image) - - if network is not None: - network.multiplier = gen_config.network_multiplier - torch.manual_seed(gen_config.seed) - torch.cuda.manual_seed(gen_config.seed) - - generator = torch.manual_seed(gen_config.seed) - - if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ - and gen_config.adapter_image_path is not None: - # run through the adapter to saturate the embeds - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - validation_image) - self.adapter(conditional_clip_embeds) - - if self.adapter is not None 
and isinstance(self.adapter, CustomAdapter): - # handle condition the prompts - gen_config.prompt = self.adapter.condition_prompt( - gen_config.prompt, - is_unconditional=False, - ) - gen_config.prompt_2 = gen_config.prompt - gen_config.negative_prompt = self.adapter.condition_prompt( - gen_config.negative_prompt, - is_unconditional=True, - ) - gen_config.negative_prompt_2 = gen_config.negative_prompt - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: - self.adapter.trigger_pre_te( - tensors_0_1=validation_image, - is_training=False, - has_been_preprocessed=False, - quad_count=4 - ) - - if self.sample_prompts_cache is not None: - conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) - unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) - else: - ctrl_img = None - has_control_images = False - if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: - has_control_images = True - # load the control image if out model uses it in text encoding - if has_control_images and self.encode_control_in_text_embeddings: - ctrl_img_list = [] + ctrl_img = None + has_control_images = False + if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: + has_control_images = True + # load the control image if out model uses it in text encoding + if has_control_images and self.encode_control_in_text_embeddings: + ctrl_img_list = [] - if gen_config.ctrl_img is not None: - ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img) + if gen_config.ctrl_img is not None: + 
ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) - if gen_config.ctrl_img_1 is not None: - ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_1 = ( - TF.to_tensor(ctrl_img_1) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_1) - if gen_config.ctrl_img_2 is not None: - ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_2 = ( - TF.to_tensor(ctrl_img_2) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_2) - if gen_config.ctrl_img_3 is not None: - ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_3 = ( - TF.to_tensor(ctrl_img_3) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_3) + if gen_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) - if self.has_multiple_control_images: - ctrl_img = ctrl_img_list - else: - ctrl_img 
= ctrl_img_list[0] if len(ctrl_img_list) > 0 else None - # encode the prompt ourselves so we can do fun stuff with embeddings - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - conditional_embeds = self.encode_prompt( - gen_config.prompt, - gen_config.prompt_2, - force_all=True, - control_images=ctrl_img + if self.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, + gen_config.prompt_2, + force_all=True, + control_images=ctrl_img + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, + gen_config.negative_prompt_2, + force_all=True, + control_images=ctrl_img + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = True - unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, - gen_config.negative_prompt_2, - force_all=True, - control_images=ctrl_img + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = 
self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter( + conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter( + unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values( + extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like( + extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError( + "Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + + img = self.generate_single_image( + pipeline, + gen_config, + conditional_embeds, + 
unconditional_embeds, + generator, + extra, ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - # allow any manipulations to take place to embeddings - gen_config.post_process_embeddings( - conditional_embeds, - unconditional_embeds, - ) - - if self.decorator is not None: - # apply the decorator to the embeddings - conditional_embeds.text_embeds = self.decorator( - conditional_embeds.text_embeds) - unconditional_embeds.text_embeds = self.decorator( - unconditional_embeds.text_embeds, is_unconditional=True) - - if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ - and gen_config.adapter_image_path is not None: - # apply the image projection - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - validation_image) - unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, - True) - conditional_embeds = self.adapter( - conditional_embeds, conditional_clip_embeds, is_unconditional=False) - unconditional_embeds = self.adapter( - unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter): - conditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=conditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_generating_samples=True, - ) - unconditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=unconditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_unconditional=True, - is_generating_samples=True, - ) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( - gen_config.extra_values) > 0: - extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, - dtype=self.torch_dtype) - # apply extra values to the embeddings - self.adapter.add_extra_values( - extra_values, 
is_unconditional=False) - self.adapter.add_extra_values(torch.zeros_like( - extra_values), is_unconditional=True) - pass # todo remove, for debugging - - if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: - # if we have a refiner loaded, set the denoising end at the refiner start - extra['denoising_end'] = gen_config.refiner_start_at - extra['output_type'] = 'latent' - if not self.is_xl: - raise ValueError( - "Refiner is only supported for XL models") - - conditional_embeds = conditional_embeds.to( - self.device_torch, dtype=self.unet.dtype) - unconditional_embeds = unconditional_embeds.to( - self.device_torch, dtype=self.unet.dtype) - - img = self.generate_single_image( - pipeline, - gen_config, - conditional_embeds, - unconditional_embeds, - generator, - extra, - ) - - gen_config.save_image(img, i) - gen_config.log_image(img, i) - self._after_sample_image(i, len(image_configs)) - flush() + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): self.adapter.clear_memory() - # Unload sampling transformer from GPU and restore main model to device - if self._sampling_transformer is not None: - self._sampling_transformer.to("cpu") - self.model.to(self.device_torch, dtype=self.torch_dtype) - - # clear pipeline and cache to reduce vram usage - del pipeline - torch.cuda.empty_cache() + finally: + # Unload sampling transformer from GPU and restore main model to device + if self._sampling_transformer is not None: + self._sampling_transformer.to("cpu") + self.model.to(self.device_torch, dtype=self.torch_dtype) + if is_debug_enabled(): + print_acc("Unloaded sampling transformer to CPU") + # Ensure CUDA work is finished so VRAM is actually released before next use + if torch.cuda.is_available(): + torch.cuda.synchronize() + # Clear pipeline and cache to reduce vram usage (only if we created pipeline here) + if 
pipeline_created: + try: + del pipeline + except NameError: + pass + torch.cuda.empty_cache() # restore training state torch.set_rng_state(rng_state) @@ -695,7 +707,6 @@ def generate_images( network.train() network.multiplier = start_multiplier - self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 1ce7a16bf..e27e83f28 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -713,9 +713,10 @@ def load_weights(self: Network, file, force_weight_mapping=False): return self.load_weights(file, force_weight_mapping=True) info = self.load_state_dict(load_sd, False) - if len(extra_dict.keys()) == 0: - extra_dict = None - return extra_dict + del load_sd + if isinstance(file, str): + del weights_sd + return None @torch.no_grad() def _update_torch_multiplier(self: Network):