diff --git a/.gitignore b/.gitignore index 04c233ac4..2a618d3f2 100644 --- a/.gitignore +++ b/.gitignore @@ -181,4 +181,8 @@ cython_debug/ ._.DS_Store aitk_db.db /notes.md -/data \ No newline at end of file +/data + +# ide +.cursor +.continue diff --git a/config/examples/WARMUP_SCHEDULER_GUIDE.md b/config/examples/WARMUP_SCHEDULER_GUIDE.md new file mode 100644 index 000000000..3f9a195fe --- /dev/null +++ b/config/examples/WARMUP_SCHEDULER_GUIDE.md @@ -0,0 +1,299 @@ +# Learning Rate Scheduler Warmup Guide + +## Overview + +This guide explains how to use the warmup functionality for learning rate schedulers in the AI Toolkit. Warmup gradually increases the learning rate from near-zero to the target learning rate over a specified number of steps, which can help stabilize training in the early stages. + +## Supported Schedulers + +Warmup is supported for the following schedulers: +- `cosine` - Cosine annealing scheduler +- `cosine_with_restarts` - Cosine annealing with warm restarts (SGDR) + +## How It Works + +When you specify `warmup_steps > 0`, the scheduler automatically creates a composite scheduler using PyTorch's `SequentialLR`: + +1. **Warmup Phase** (steps 0 to `warmup_steps`): + - Uses `LinearLR` to gradually increase learning rate from ~0 to the target LR + - Learning rate increases linearly: `lr = target_lr * (current_step / warmup_steps)` + +2. 
**Main Phase** (steps `warmup_steps` to `total_steps`): + - Uses the specified scheduler (cosine or cosine_with_restarts) + - `total_iters` specifies the TOTAL number of training iterations (including warmup) + - `T_0`/`T_max` specify iterations for the MAIN scheduler phase (after warmup) + - If T_0/T_max is specified, it takes priority over calculated value from total_iters + - Main scheduler iterations = total_iters - warmup_steps (or T_0/T_max if specified) + +### Learning Rate Progression + +``` +LR | + | _/--\ /\ /\ + | / \ / \ / \ + | / \ / \ / \ + | / \/ \ / \ + |/ \/ \ + +----------------------------> steps + |<-warmup->|<-- cosine restarts --> +``` + +## Parameter Semantics + +### Key concepts + +- **`total_iters`**: TOTAL training iterations (including warmup) +- **`T_0`/`T_max`**: Main scheduler iterations (after warmup), overrides calculation from total_iters + +### Example: Training with 1000 total steps + +**Config 1: Using total_iters (automatic calculation)** + +```yaml +train: + steps: 1000 # Will be used as total_iters by BaseSDTrainProcess + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 + # total_iters will be auto-set to 1000 by BaseSDTrainProcess + T_mult: 2 +``` + +**Result:** +- Steps 0-100: Linear warmup +- Steps 100-1000: Cosine with restarts (900 iterations = 1000 - 100) +- Total: 1000 steps ✓ + +**Config 2: Using T_0 explicitly (overrides calculation)** + +```yaml +train: + steps: 1000 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 + total_iters: 1000 + T_0: 500 # Overrides default calculation (1000 - 100) + T_mult: 2 +``` + +**Result:** +- Steps 0-100: Linear warmup +- Steps 100-600: Cosine with restarts (500 iterations from T_0) +- Total: 600 steps (less than total_iters!) + +**Priority:** If both `T_0` and `total_iters` are specified, `T_0` takes priority and determines main scheduler length. 
+ +## Configuration Examples + +### Example 1: Cosine with Restarts + Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 # Warmup for first 100 steps + T_mult: 2 # Double restart period each cycle + eta_min: 1e-7 # Minimum learning rate +``` + +### Example 2: Cosine + Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine" + lr_scheduler_params: + warmup_steps: 100 # Warmup for first 100 steps + eta_min: 1e-7 # Minimum learning rate +``` + +### Example 3: Without Warmup (Backward Compatible) + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + T_mult: 2 + eta_min: 1e-7 + # No warmup_steps specified = no warmup +``` + +### Example 4: Explicitly Disable Warmup + +```yaml +train: + steps: 2000 + lr: 1e-4 + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 0 # Explicitly disable warmup + T_mult: 2 + eta_min: 1e-7 +``` + +## Parameters + +### Common Parameters + +- **`warmup_steps`** (optional, default: 0) + - Number of steps for the warmup phase + - Set to 0 or omit to disable warmup + - Typical values: 50-500 depending on total training steps + - Rule of thumb: 5-10% of total steps + +### Cosine Scheduler Parameters + +- **`eta_min`** (optional, default: 0) + - Minimum learning rate + +### Cosine with Restarts Parameters + +- **`T_mult`** (optional, default: 1) + - Factor to increase the restart period after each restart + - `T_mult=1`: equal restart periods + - `T_mult=2`: double the period each time + +- **`eta_min`** (optional, default: 0) + - Minimum learning rate + +## Choosing Warmup Steps + +### General Guidelines + +1. **Small datasets (< 1000 images)** + - Warmup steps: 50-100 + - Helps prevent overfitting to early batches + +2. **Medium datasets (1000-10000 images)** + - Warmup steps: 100-250 + - Balances stability and training time + +3. 
**Large datasets (> 10000 images)** + - Warmup steps: 250-500 + - More warmup helps with stability + +4. **Percentage-based approach** + - Use 5-10% of total training steps + - Example: 2000 steps → 100-200 warmup steps + +### When to Use Warmup + +✅ **Use warmup when:** +- Training from scratch or with random initialization +- Using high learning rates (> 1e-4) +- Experiencing unstable early training +- Training large models or with large batch sizes +- Using aggressive optimizers (Adam with high β values) + +❌ **Skip warmup when:** +- Fine-tuning from a well-trained checkpoint +- Using very low learning rates (< 1e-5) +- Training is already stable without warmup +- Total training steps are very small (< 500) + +## Implementation Details + +### Under the Hood + +The implementation uses PyTorch's built-in schedulers: + +```python +# Warmup phase +warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1e-10, # Start almost at 0 + end_factor=1.0, # End at full LR + total_iters=warmup_steps +) + +# Main phase (example for cosine_with_restarts) +main_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=total_iters - warmup_steps, # Calculated from total iterations + # Or explicitly specify main scheduler iterations: + # T_0=900, # Direct specification (ignores total_iters calculation) + T_mult=2, + eta_min=1e-7 +) + +# Combined scheduler +combined_scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_steps] +) +``` + +### Backward Compatibility + +- All existing configurations continue to work without changes +- Warmup is only activated when `warmup_steps > 0` is explicitly specified +- Default behavior (no warmup) is preserved when `warmup_steps` is not specified + +## Testing + +A test script is provided to verify the warmup functionality: + +```bash +# Activate your virtual environment +.\venv\Scripts\activate.ps1 # Windows 
PowerShell +# or +source venv/bin/activate # Linux/Mac + +# Run tests +python test_scheduler_warmup.py +``` + +The test script verifies: +1. Backward compatibility (schedulers work without warmup) +2. Warmup functionality (SequentialLR is created when warmup_steps > 0) +3. Learning rate progression (LR increases during warmup, then follows main scheduler) + +## Full Configuration Example + +See `config/examples/train_lora_flux_with_warmup.yaml` for a complete working example. + +## Troubleshooting + +### Issue: Learning rate doesn't increase during warmup + +**Solution:** Make sure `warmup_steps` is specified in `lr_scheduler_params`, not at the top level of `train`. + +```yaml +# ❌ Wrong +train: + warmup_steps: 100 + lr_scheduler_params: + T_mult: 2 + +# ✅ Correct +train: + lr_scheduler_params: + warmup_steps: 100 + T_mult: 2 +``` + +### Issue: Training is unstable even with warmup + +**Possible solutions:** +1. Increase `warmup_steps` (try 10-15% of total steps) +2. Reduce base learning rate +3. Use gradient clipping (`max_grad_norm`) +4. Reduce batch size + +### Issue: Warmup takes too long + +**Solution:** Reduce `warmup_steps`. Remember, warmup is just the initial phase. If warmup is more than 10-15% of total training, it might be too long. 
+ +## References + +- [PyTorch CosineAnnealingLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html) +- [PyTorch CosineAnnealingWarmRestarts](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html) +- [PyTorch SequentialLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.SequentialLR.html) +- [SGDR: Stochastic Gradient Descent with Warm Restarts](https://arxiv.org/abs/1608.03983) diff --git a/config/examples/train_lora_flux_with_warmup.yaml b/config/examples/train_lora_flux_with_warmup.yaml new file mode 100644 index 000000000..18a473141 --- /dev/null +++ b/config/examples/train_lora_flux_with_warmup.yaml @@ -0,0 +1,79 @@ +--- +# Example configuration demonstrating warmup support for cosine_with_restarts scheduler +# This is based on train_lora_flux_24gb.yaml with added warmup configuration + +job: extension +config: + name: "flux_lora_with_warmup_v1" + process: + - type: 'sd_trainer' + training_folder: "output" + device: cuda:0 + + network: + type: "lora" + linear: 16 + linear_alpha: 16 + + save: + dtype: float16 + save_every: 250 + max_step_saves_to_keep: 4 + push_to_hub: false + + datasets: + - folder_path: "/path/to/images/folder" + caption_ext: "txt" + caption_dropout_rate: 0.05 + shuffle_tokens: false + cache_latents_to_disk: true + resolution: [ 512, 768, 1024 ] + + train: + batch_size: 1 + steps: 2000 + gradient_accumulation_steps: 1 + train_unet: true + train_text_encoder: false + gradient_checkpointing: true + noise_scheduler: "flowmatch" + optimizer: "adamw8bit" + lr: 1e-4 + + # Learning rate scheduler with warmup + lr_scheduler: "cosine_with_restarts" + lr_scheduler_params: + warmup_steps: 100 # Linear warmup for first 100 steps (0 → lr) + T_mult: 2 # Double the restart period each time + eta_min: 1e-7 # Minimum learning rate + + # Alternative: cosine scheduler with warmup + # lr_scheduler: "cosine" + # lr_scheduler_params: + # warmup_steps: 100 # 
Linear warmup for first 100 steps + # eta_min: 1e-7 # Minimum learning rate + + # For backward compatibility (no warmup): + # lr_scheduler: "cosine_with_restarts" + # lr_scheduler_params: + # T_mult: 2 + # eta_min: 1e-7 + + model: + name_or_path: "black-forest-labs/FLUX.1-dev" + is_flux: true + quantize: true + + sample: + sampler: "flowmatch" + sample_every: 250 + width: 1024 + height: 1024 + prompts: + - "a photo of a cat" + - "a photo of a dog" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4 + sample_steps: 20 diff --git a/extensions_built_in/diffusion_models/ltx2/ltx2.py b/extensions_built_in/diffusion_models/ltx2/ltx2.py index a6d1e6dc7..36d32787f 100644 --- a/extensions_built_in/diffusion_models/ltx2/ltx2.py +++ b/extensions_built_in/diffusion_models/ltx2/ltx2.py @@ -18,6 +18,7 @@ from toolkit.accelerator import unwrap_model from optimum.quanto import freeze from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device from toolkit.memory_management import MemoryManager from safetensors.torch import load_file from PIL import Image @@ -570,8 +571,8 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - if self.model.device == torch.device("cpu"): - self.model.to(self.device_torch) + if self.low_vram and self.model.device == torch.device("cpu"): + safe_module_to_device(self.model, self.device_torch) # handle control image if gen_config.ctrl_img is not None: @@ -771,9 +772,9 @@ def get_noise_prediction( **kwargs, ): with torch.no_grad(): - if self.model.device == torch.device("cpu"): - self.model.to(self.device_torch) - + if self.low_vram and self.model.device == torch.device("cpu"): + safe_module_to_device(self.model, self.device_torch) + # We only encode and store the minimum prompt tokens, but need them padded to 1024 for LTX2 text_embeddings = self.pad_embeds(text_embeddings) diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py 
b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py index 6e6813ffb..20772a29d 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py @@ -15,6 +15,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from toolkit.memory_management import MemoryManager from safetensors.torch import load_file @@ -262,7 +263,8 @@ def generate_single_image( control_img = control_img.resize( (gen_config.width, gen_config.height), Image.BILINEAR ) - self.model.to(self.device_torch) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch) # flush for low vram if we are doing that flush_between_steps = self.model_config.low_vram @@ -305,7 +307,8 @@ def get_noise_prediction( text_embeddings: PromptEmbeds, **kwargs, ): - self.model.to(self.device_torch) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch) batch_size, num_channels_latents, height, width = latent_model_input.shape ps = self.transformer.config.patch_size diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py index bcc8d735c..328cb81f9 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py @@ -16,6 +16,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from diffusers import ( @@ -89,7 +90,8 @@ def generate_single_image( generator: torch.Generator, extra: 
dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch, self.torch_dtype) sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py index 8272ee464..dfe772575 100644 --- a/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py +++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py @@ -16,6 +16,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from optimum.quanto import freeze, QTensor from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device import torch.nn.functional as F from diffusers import ( @@ -97,7 +98,8 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) + if self.model_config.low_vram: + safe_module_to_device(self.model, self.device_torch, self.torch_dtype) sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py index a32183cec..9205d7d8a 100644 --- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py +++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py @@ -16,6 +16,7 @@ CustomFlowMatchEulerDiscreteScheduler, ) from toolkit.util.quantize import quantize_model +from toolkit.util.device import safe_module_to_device from .wan22_pipeline import Wan22Pipeline from diffusers import WanTransformer3DModel @@ -132,20 +133,23 @@ def forward( # todo swap the loras as well if t_name 
!= self._active_transformer_name: if self.low_vram: - getattr(self, self._active_transformer_name).to("cpu") - getattr(self, t_name).to(self.device_torch) + safe_module_to_device( + getattr(self, self._active_transformer_name), torch.device("cpu") + ) + safe_module_to_device(getattr(self, t_name), self.device_torch) torch.cuda.empty_cache() self._active_transformer_name = t_name - if self.transformer.device != hidden_states.device: - if self.low_vram: - # move other transformer to cpu - other_tname = ( - "transformer_1" if t_name == "transformer_2" else "transformer_2" - ) - getattr(self, other_tname).to("cpu") + if self.low_vram and self.transformer.device != hidden_states.device: + # move other transformer to cpu + other_tname = ( + "transformer_1" if t_name == "transformer_2" else "transformer_2" + ) + safe_module_to_device( + getattr(self, other_tname), torch.device("cpu") + ) - self.transformer.to(hidden_states.device) + safe_module_to_device(self.transformer, hidden_states.device) return self.transformer( hidden_states=hidden_states, diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py index 368ae9e7c..7fc257757 100644 --- a/extensions_built_in/diffusion_models/z_image/z_image.py +++ b/extensions_built_in/diffusion_models/z_image/z_image.py @@ -1,5 +1,6 @@ import os -from typing import List, Optional +import time +from typing import List, Optional, Tuple import huggingface_hub import torch @@ -14,9 +15,13 @@ ) from toolkit.accelerator import unwrap_model from optimum.quanto import freeze +from toolkit.util.debug import memory_debug from toolkit.util.quantize import quantize, get_qtype, quantize_model +from toolkit.util.device import safe_module_to_device from toolkit.memory_management import MemoryManager +from toolkit.paths import normalize_path from safetensors.torch import load_file +import safetensors from transformers import AutoTokenizer, Qwen3ForCausalLM from diffusers import 
AutoencoderKL @@ -36,6 +41,9 @@ "shift": 3.0, } +# Set when safetensors patches are applied for debug_zimage_load (avoid double-wrap) +_zimage_load_debug_patched = False + class ZImageModel(BaseModel): arch = "zimage" @@ -147,11 +155,46 @@ def load_training_adapter(self, transformer: ZImageTransformer2DModel): # tell the model to invert assistant on inference since we want remove lora effects self.invert_assistant_lora = True - def load_model(self): + def _load_sampling_transformer(self) -> Optional[ZImageTransformer2DModel]: + """Load a separate transformer for inference/sampling when sampling_name_or_path is set. + Returns None if not configured or when sampling path equals the training model path. + """ + if ( + self.model_config.sampling_name_or_path is None + or self.model_config.sampling_name_or_path == self.model_config.name_or_path + ): + return None dtype = self.torch_dtype - self.print_and_status_update("Loading ZImage model") - model_path = self.model_config.name_or_path - base_model_path = self.model_config.extras_name_or_path + self.print_and_status_update("Loading sampling transformer") + sampling_model_path = normalize_path(self.model_config.sampling_name_or_path) + sampling_transformer_path = sampling_model_path + sampling_transformer_subfolder = "transformer" + + if os.path.exists(sampling_model_path): + sampling_transformer_subfolder = None + sampling_transformer_path = os.path.join(sampling_model_path, "transformer") + + sampling_transformer = ZImageTransformer2DModel.from_pretrained( + sampling_transformer_path, + subfolder=sampling_transformer_subfolder, + torch_dtype=dtype, + device_map="cpu", + ) + if self.model_config.quantize: + self.print_and_status_update("Quantizing sampling transformer") + quantize_model(self, sampling_transformer) + flush() + # Already on CPU via device_map="cpu"; ensure it stays there + sampling_transformer.to("cpu") + flush() + return sampling_transformer + + def _load_transformer(self, model_path: str) -> 
Tuple[ZImageTransformer2DModel, str]: + """Load the training transformer and resolve base_model_path for VAE/text_encoder. + Returns (transformer, base_model_path). base_model_path is used for tokenizer, text_encoder, VAE. + """ + dtype = self.torch_dtype + base_model_path = normalize_path(self.model_config.extras_name_or_path) self.print_and_status_update("Loading transformer") @@ -160,20 +203,19 @@ def load_model(self): if os.path.exists(transformer_path): transformer_subfolder = None transformer_path = os.path.join(transformer_path, "transformer") - # check if the path is a full checkpoint. + # Check if the path is a full checkpoint (contains text_encoder, vae, etc.) te_folder_path = os.path.join(model_path, "text_encoder") - # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): base_model_path = model_path transformer = ZImageTransformer2DModel.from_pretrained( transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype ) - - # load assistant lora if specified + # Load assistant LoRA if specified (e.g. 
for ARA-style training) if self.model_config.assistant_lora_path is not None: self.load_training_adapter(transformer) - # set qtype to be float8 if it is qfloat8 + # Set qtype to float8 if it is qfloat8 (quantize compatibility) if self.model_config.qtype == "qfloat8": self.model_config.qtype = "float8" @@ -201,93 +243,165 @@ def load_model(self): transformer.to("cpu") flush() + return transformer, base_model_path - self.print_and_status_update("Text Encoder") - tokenizer = AutoTokenizer.from_pretrained( - base_model_path, subfolder="tokenizer", torch_dtype=dtype - ) - text_encoder = Qwen3ForCausalLM.from_pretrained( - base_model_path, subfolder="text_encoder", torch_dtype=dtype - ) + def load_model(self): + global _zimage_load_debug_patched + dtype = self.torch_dtype + self.print_and_status_update("Loading ZImage model") + model_path = normalize_path(self.model_config.name_or_path) + + if self.model_config.debug_zimage_load and not _zimage_load_debug_patched: + log_func = self.print_and_status_update + orig_load_file = safetensors.torch.load_file + orig_safe_open = safetensors.safe_open + + def _debug_path_and_size(path): + if path is not None and isinstance(path, (str, os.PathLike)) and os.path.isfile(path): + path_str = os.path.abspath(path) + try: + size = os.path.getsize(path) + return path_str, f" size={size}" + except OSError: + return path_str, "" + path_str = repr(path) if path else "" + return path_str, "" + + def _wrapped_load_file(*args, **kwargs): + path = args[0] if args else kwargs.get("filename") or kwargs.get("path") + path_str, size_str = _debug_path_and_size(path) + start = time.perf_counter() + result = orig_load_file(*args, **kwargs) + duration = time.perf_counter() - start + log_func(f"[ZImage debug] load_file path={path_str}{size_str} duration={duration:.3f}s") + return result + + def _wrapped_safe_open(*args, **kwargs): + path = args[0] if args else kwargs.get("filename") or kwargs.get("path") + path_str, size_str = _debug_path_and_size(path) 
+ start = time.perf_counter() + result = orig_safe_open(*args, **kwargs) + + class _WrappedCtx: + def __enter__(_self): + return result.__enter__() + + def __exit__(_self, *exc): + duration = time.perf_counter() - start + log_func(f"[ZImage debug] safe_open path={path_str}{size_str} duration={duration:.3f}s") + return result.__exit__(*exc) + + return _WrappedCtx() + + safetensors.torch.load_file = _wrapped_load_file + safetensors.safe_open = _wrapped_safe_open + _zimage_load_debug_patched = True + + if self.model_config.debug_zimage_load: + self.print_and_status_update("[ZImage debug] === Loading sampling transformer (from_pretrained) ===") + # Load sampling transformer first (if configured) to control peak VRAM + with memory_debug(self.print_and_status_update, "Loading sampling transformer"): + self._sampling_transformer = self._load_sampling_transformer() + if self.model_config.debug_zimage_load: + self.print_and_status_update("[ZImage debug] === Loading main transformer (from_pretrained) ===") + with memory_debug(self.print_and_status_update, "Loading transformer"): + transformer, base_model_path = self._load_transformer(model_path) - if ( - self.model_config.layer_offloading - and self.model_config.layer_offloading_text_encoder_percent > 0 - ): - MemoryManager.attach( - text_encoder, - self.device_torch, - offload_percent=self.model_config.layer_offloading_text_encoder_percent, + self.print_and_status_update("Text Encoder") + with memory_debug(self.print_and_status_update, "Text Encoder"): + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", dtype=dtype + ) + text_encoder = Qwen3ForCausalLM.from_pretrained( + base_model_path, subfolder="text_encoder", dtype=dtype ) - text_encoder.to(self.device_torch, dtype=dtype) - flush() + if ( + self.model_config.layer_offloading + and self.model_config.layer_offloading_text_encoder_percent > 0 + ): + MemoryManager.attach( + text_encoder, + self.device_torch, + 
offload_percent=self.model_config.layer_offloading_text_encoder_percent, + ) - if self.model_config.quantize_te: - self.print_and_status_update("Quantizing Text Encoder") - quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) - freeze(text_encoder) + text_encoder.to(self.device_torch, dtype=dtype) flush() + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing Text Encoder") + quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te)) + freeze(text_encoder) + flush() + self.print_and_status_update("Loading VAE") - vae = AutoencoderKL.from_pretrained( - base_model_path, subfolder="vae", torch_dtype=dtype - ) + with memory_debug(self.print_and_status_update, "Loading VAE"): + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype + ) self.noise_scheduler = ZImageModel.get_train_scheduler() self.print_and_status_update("Making pipe") - - kwargs = {} - - pipe: ZImagePipeline = ZImagePipeline( - scheduler=self.noise_scheduler, - text_encoder=None, - tokenizer=tokenizer, - vae=vae, - transformer=None, - **kwargs, - ) - # for quantization, it works best to do these after making the pipe - pipe.text_encoder = text_encoder - pipe.transformer = transformer + with memory_debug(self.print_and_status_update, "Making pipe"): + kwargs = {} + + pipe: ZImagePipeline = ZImagePipeline( + scheduler=self.noise_scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + **kwargs, + ) + # for quantization, it works best to do these after making the pipe + pipe.text_encoder = text_encoder + pipe.transformer = transformer self.print_and_status_update("Preparing Model") + with memory_debug(self.print_and_status_update, "Preparing Model"): + text_encoder = [pipe.text_encoder] + tokenizer = [pipe.tokenizer] - text_encoder = [pipe.text_encoder] - tokenizer = [pipe.tokenizer] - - # leave it on cpu for now - if not self.low_vram: - pipe.transformer = 
pipe.transformer.to(self.device_torch) + # leave it on cpu for now + if not self.low_vram: + pipe.transformer = pipe.transformer.to(self.device_torch) - flush() - # just to make sure everything is on the right device and dtype - text_encoder[0].to(self.device_torch) - text_encoder[0].requires_grad_(False) - text_encoder[0].eval() - flush() + flush() + # just to make sure everything is on the right device and dtype + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + flush() # save it to the model class - self.vae = vae - self.text_encoder = text_encoder # list of text encoders - self.tokenizer = tokenizer # list of tokenizers - self.model = pipe.transformer - self.pipeline = pipe - self.print_and_status_update("Model Loaded") + with memory_debug(self.print_and_status_update, "Model Loaded"): + self.vae = vae + self.text_encoder = text_encoder # list of text encoders + self.tokenizer = tokenizer # list of tokenizers + self.model = pipe.transformer + self.pipeline = pipe + self.print_and_status_update("Model Loaded") def get_generation_pipeline(self): scheduler = ZImageModel.get_train_scheduler() + # Determine which transformer to use for sampling + transformer_for_sampling = self.transformer + if self._sampling_transformer is not None: + # Use sampling transformer, keep it on CPU (for testing / avoid OOM) + transformer_for_sampling = self._sampling_transformer + pipeline: ZImagePipeline = ZImagePipeline( scheduler=scheduler, text_encoder=unwrap_model(self.text_encoder[0]), tokenizer=self.tokenizer[0], vae=unwrap_model(self.vae), - transformer=unwrap_model(self.transformer), + transformer=unwrap_model(transformer_for_sampling), ) - pipeline = pipeline.to(self.device_torch) + # pipeline = pipeline.to(self.device_torch) return pipeline @@ -300,15 +414,36 @@ def generate_single_image( generator: torch.Generator, extra: dict, ): - self.model.to(self.device_torch, dtype=self.torch_dtype) - 
self.model.to(self.device_torch) - sc = self.get_bucket_divisibility() gen_config.width = int(gen_config.width // sc * sc) gen_config.height = int(gen_config.height // sc * sc) + + # ZImagePipeline expects prompt_embeds and negative_prompt_embeds to be + # List[torch.FloatTensor] where each element is [seq_len, dim]. + # The pipeline concatenates these lists for CFG (not element-wise add). + cond_embeds = conditional_embeds.text_embeds + uncond_embeds = unconditional_embeds.text_embeds + + # Convert rank-3 tensors back to list of rank-2 tensors + def to_embed_list(embeds): + if embeds is None: + return [] + if isinstance(embeds, list): + return embeds + if len(embeds.shape) == 3: + # [batch, seq_len, dim] -> list of [seq_len, dim] + return list(embeds.unbind(dim=0)) + elif len(embeds.shape) == 2: + # Already [seq_len, dim], wrap in list + return [embeds] + return embeds + + cond_embeds_list = to_embed_list(cond_embeds) + uncond_embeds_list = to_embed_list(uncond_embeds) + img = pipeline( - prompt_embeds=conditional_embeds.text_embeds, - negative_prompt_embeds=unconditional_embeds.text_embeds, + prompt_embeds=cond_embeds_list, + negative_prompt_embeds=uncond_embeds_list, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, @@ -326,17 +461,33 @@ def get_noise_prediction( text_embeddings: PromptEmbeds, **kwargs, ): - self.model.to(self.device_torch) + if self.low_vram and next(self.model.parameters()).device != self.device_torch: + safe_module_to_device(self.model, self.device_torch) latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) timestep_model_input = (1000 - timestep) / 1000 + text_embeds = text_embeddings.text_embeds + if isinstance(text_embeds, torch.Tensor): + if len(text_embeds.shape) == 3: + # if it is a single batch tensor, unbind it into a list of tensors + text_embeds = list(text_embeds.unbind(dim=0)) + elif isinstance(text_embeds, list): + 
# check if items are rank 3 (batch, length, dim) + if len(text_embeds[0].shape) == 3: + # flatten the list of batches into a single list of tensors + new_text_embeds = [] + for t in text_embeds: + if t is not None: + new_text_embeds += list(t.unbind(dim=0)) + text_embeds = new_text_embeds + model_out_list = self.transformer( latent_model_input_list, timestep_model_input, - text_embeddings.text_embeds, + text_embeds, )[0] noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) @@ -355,6 +506,24 @@ def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: do_classifier_free_guidance=False, device=self.device_torch, ) + # encode_prompt returns list of rank-2 tensors [seq_len, dim] + # Pad to same length and stack into rank-3 tensor [batch, seq_len, dim] + # for compatibility with concat_prompt_embeds and predict_noise + if isinstance(prompt_embeds, list): + # Find max sequence length, TODO: Or just use 512? + max_seq_len = max(t.shape[0] for t in prompt_embeds) + # Pad each tensor to max length + padded = [] + for t in prompt_embeds: + if t.shape[0] < max_seq_len: + pad = torch.zeros( + (max_seq_len - t.shape[0], t.shape[1]), + dtype=t.dtype, + device=t.device, + ) + t = torch.cat([t, pad], dim=0) + padded.append(t) + prompt_embeds = torch.stack(padded, dim=0) pe = PromptEmbeds([prompt_embeds, None]) return pe @@ -375,6 +544,41 @@ def save_model(self, output_path, meta, save_dtype): with open(meta_path, "w") as f: yaml.dump(meta, f) + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + ): + """ + Override generate_images to handle LoRA on sampling transformer. + When using _sampling_transformer with LoRA, temporarily swap self.network + to _sampling_network (which has LoRA applied to _sampling_transformer) + for the duration of generation. 
+ """ + saved_network = None + try: + # If using sampling transformer with LoRA, swap to sampling network (weights already shared) + if ( + hasattr(self, '_sampling_transformer') + and self._sampling_transformer is not None + and hasattr(self, '_sampling_network') + and self._sampling_network is not None + and hasattr(self, 'network') + and self.network is not None + ): + saved_network = self.network + self.network = self._sampling_network + + # Call parent's generate_images with possibly swapped network + return super().generate_images(image_configs, sampler) + finally: + # Restore original network + if saved_network is not None: + self.network = saved_network + # Ensure sampling transformer is off GPU (defense in depth) + if getattr(self, "_sampling_transformer", None) is not None: + self._sampling_transformer.to("cpu") + def get_loss_target(self, *args, **kwargs): noise = kwargs.get("noise") batch = kwargs.get("batch") diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 152cd131c..650a4abbe 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -6,7 +6,6 @@ import numpy as np from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel -import torch.functional as F from safetensors.torch import load_file from torch.utils.data import DataLoader, ConcatDataset @@ -34,6 +33,7 @@ import math from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe +from toolkit.util.debug import memory_debug from toolkit.util.losses import wavelet_loss, stepped_loss import torch.nn.functional as F from toolkit.unloader import unload_text_encoder @@ -112,6 +112,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0]) else: raise ValueError(f"Unknown guidance 
loss target type {type(self.train_config.guidance_loss_target)}") + + # store differential guidance norm metric for logging + self.diff_guidance_norm = None def before_model_load(self): @@ -341,12 +344,15 @@ def hook_before_train_loop(self): # unload the text encoder if self.is_caching_text_embeddings: - unload_text_encoder(self.sd) + with memory_debug(print_acc, "UNLOAD TEXT ENCODER"): + unload_text_encoder(self.sd) + flush() else: # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing. # keep legacy usage for now. - self.sd.text_encoder_to("cpu") - flush() + with memory_debug(print_acc, "UNLOAD TEXT ENCODER"): + self.sd.text_encoder_to("cpu") + flush() if self.train_config.blank_prompt_preservation and self.cached_blank_embeds is None: # make sure we have this if not unloading @@ -708,11 +714,14 @@ def calculate_loss( unconditional_target = unconditional_target * alpha target = unconditional_target + guidance_scale * (target - unconditional_target) - - if self.train_config.do_differential_guidance: - with torch.no_grad(): - guidance_scale = self.train_config.differential_guidance_scale - target = noise_pred + guidance_scale * (target - noise_pred) + + if self.train_config.do_differential_guidance: + with torch.no_grad(): + guidance_scale = self.train_config.differential_guidance_scale + diff_correction = guidance_scale * (target - noise_pred) + # Calculate L2 norm for monitoring + self.diff_guidance_norm = torch.norm(diff_correction).item() + target = noise_pred + diff_correction if target is None: target = noise @@ -789,6 +798,22 @@ def calculate_loss( elif len(loss.shape) == 5: timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() loss = loss * timestep_weight + elif self.train_config.content_or_style == 'fixed_cycle' and self.train_config.fixed_cycle_weight_peak_timesteps: + # fixed_cycle: weight loss by Gaussian peaks at fixed_cycle_weight_peak_timesteps (e.g. 
500, 375), mean-normalized + peaks = self.train_config.fixed_cycle_weight_peak_timesteps + sigma = self.train_config.fixed_cycle_weight_sigma + peaks_t = torch.tensor(peaks, device=timesteps.device, dtype=timesteps.dtype) + diff = timesteps.unsqueeze(1).float() - peaks_t.unsqueeze(0) + weight_per_timestep = torch.exp(-(diff / sigma) ** 2).max(dim=1)[0] + cycle_ts = torch.tensor(self.train_config.fixed_cycle_timesteps, device=timesteps.device, dtype=timesteps.dtype) + diff_cycle = cycle_ts.unsqueeze(1).float() - peaks_t.unsqueeze(0) + mean_w = torch.exp(-(diff_cycle / sigma) ** 2).max(dim=1)[0].mean().clamp(min=1e-8) + timestep_weight = (weight_per_timestep / mean_w).to(loss.device, dtype=loss.dtype) + if len(loss.shape) == 4: + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + elif len(loss.shape) == 5: + timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach() + loss = loss * timestep_weight if self.train_config.do_prior_divergence and prior_pred is not None: loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) @@ -1505,13 +1530,20 @@ def get_adapter_multiplier(): self.device_torch, dtype=dtype ) else: - embeds_to_use = self.cached_blank_embeds.clone().detach().to( - self.device_torch, dtype=dtype - ) + # No cache on disk: use trigger_word instead of blank embedding if self.cached_trigger_embeds is not None and not is_reg: embeds_to_use = self.cached_trigger_embeds.clone().detach().to( self.device_torch, dtype=dtype ) + else: + if self.is_caching_text_embeddings: + raise ValueError( + "cache_text_embeddings is enabled but no cached embeds in batch. " + "Set trigger_word when using cache_text_embeddings so fallback is available when cache is missing." 
+ ) + embeds_to_use = self.cached_blank_embeds.clone().detach().to( + self.device_torch, dtype=dtype + ) conditional_embeds = concat_prompt_embeds( [embeds_to_use] * noisy_latents.shape[0] ) @@ -1983,15 +2015,19 @@ def get_adapter_multiplier(): prior_pred=prior_to_calculate_loss, ) - if self.train_config.diff_output_preservation or self.train_config.blank_prompt_preservation: - # send the loss backwards otherwise checkpointing will fail - self.accelerator.backward(loss) - normal_loss = loss.detach() # dont send backward again - + # Determine if we should run BPP this step + should_run_bpp = False + if self.train_config.blank_prompt_preservation: + should_run_bpp = random.random() < self.train_config.blank_prompt_probability + + # Flag to track if backward was already performed in preservation block + backward_done = False + + if self.train_config.diff_output_preservation or should_run_bpp: with torch.no_grad(): if self.train_config.diff_output_preservation: preservation_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) - elif self.train_config.blank_prompt_preservation: + elif should_run_bpp: blank_embeds = self.cached_blank_embeds.clone().detach().to( self.device_torch, dtype=dtype ) @@ -2008,12 +2044,14 @@ def get_adapter_multiplier(): ) multiplier = self.train_config.diff_output_preservation_multiplier if self.train_config.diff_output_preservation else self.train_config.blank_prompt_preservation_multiplier preservation_loss = torch.nn.functional.mse_loss(preservation_pred, prior_pred) * multiplier - self.accelerator.backward(preservation_loss) - - loss = normal_loss + preservation_loss - loss = loss.clone().detach() - # require grad again so the backward wont fail - loss.requires_grad_(True) + + # Combine main loss and preservation loss with loss_multiplier applied + total_loss = loss * loss_multiplier.mean() + preservation_loss + self.accelerator.backward(total_loss) + + # Detach total loss for logging and nan checks + loss = 
total_loss.detach() + backward_done = True # check if nan if torch.isnan(loss): @@ -2021,18 +2059,20 @@ def get_adapter_multiplier(): loss = torch.zeros_like(loss).requires_grad_(True) with self.timer('backward'): - # todo we have multiplier seperated. works for now as res are not in same batch, but need to change - loss = loss * loss_multiplier.mean() - # IMPORTANT if gradient checkpointing do not leave with network when doing backward - # it will destroy the gradients. This is because the network is a context manager - # and will change the multipliers back to 0.0 when exiting. They will be - # 0.0 for the backward pass and the gradients will be 0.0 - # I spent weeks on fighting this. DON'T DO IT - # with fsdp_overlap_step_with_backward(): - # if self.is_bfloat: - # loss.backward() - # else: - self.accelerator.backward(loss) + # Only perform backward if not already done in preservation block + if not backward_done: + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. 
DON'T DO IT + # with fsdp_overlap_step_with_backward(): + # if self.is_bfloat: + # loss.backward() + # else: + self.accelerator.backward(loss) return loss.detach() # flush() diff --git a/extensions_built_in/sd_trainer/config/train.example.yaml b/extensions_built_in/sd_trainer/config/train.example.yaml index 793d5d55b..4587e786e 100644 --- a/extensions_built_in/sd_trainer/config/train.example.yaml +++ b/extensions_built_in/sd_trainer/config/train.example.yaml @@ -76,6 +76,7 @@ config: log_every: 10 # log every this many steps use_wandb: false # not supported yet verbose: false + debug: false # enable to log CUDA memory around text encoder unload # You can put any information you want here, and it will be saved in the model. # The below is an example, but you can put your grocery list in it if you want. diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 925d34daa..0033e433a 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -4,6 +4,7 @@ import json import random import shutil +import time from collections import OrderedDict import os import re @@ -26,7 +27,7 @@ from toolkit.basic import value_map from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.custom_adapter import CustomAdapter -from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch +from toolkit.data_loader import get_dataloader_from_datasets, get_dataloader_datasets, trigger_dataloader_setup_epoch from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.ema import ExponentialMovingAverage from toolkit.embedding import Embedding @@ -71,6 +72,7 @@ import hashlib from toolkit.util.blended_blur_noise import get_blended_blur_noise +from toolkit.util.debug import memory_debug, set_debug_config from toolkit.util.get_model import get_model_class def flush(): @@ -98,6 +100,9 @@ def __init__(self, process_id: int, job, config: OrderedDict, 
custom_pipeline=No self.start_step = 0 self.epoch_num = 0 self.last_save_step = 0 + # For logging timestep debugging + self._collected_indices = [] + self._collected_timesteps = [] # start at 1 so we can do a sample at the start self.grad_accumulation_step = 1 # if true, then we do not do an optimizer step. We are accumulating gradients @@ -127,6 +132,7 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No self.has_first_sample_requested = False self.first_sample_config = self.sample_config self.logging_config = LoggingConfig(**self.get_conf('logging', {})) + set_debug_config(self.logging_config) self.logger = create_logger(self.logging_config, config, self.save_root) self.optimizer: torch.optim.Optimizer = None self.lr_scheduler = None @@ -469,14 +475,29 @@ def clean_up_saves(self): for item in items_to_remove: print_acc(f"Removing old save: {item}") - if os.path.isdir(item): - shutil.rmtree(item) - else: - os.remove(item) - # see if a yaml file with same name exists - yaml_file = os.path.splitext(item)[0] + ".yaml" - if os.path.exists(yaml_file): - os.remove(yaml_file) + try: + if os.path.isdir(item): + shutil.rmtree(item) + else: + os.remove(item) + # see if a yaml file with same name exists + yaml_file = os.path.splitext(item)[0] + ".yaml" + if os.path.exists(yaml_file): + os.remove(yaml_file) + except PermissionError as e: + print_acc(f"Could not remove {item}: {e}. Skipping (file may be in use).") + try: + time.sleep(2) + if os.path.isdir(item): + shutil.rmtree(item) + else: + os.remove(item) + yaml_file = os.path.splitext(item)[0] + ".yaml" + if os.path.exists(yaml_file): + os.remove(yaml_file) + print_acc(f"Retry succeeded: removed {item}") + except PermissionError: + print_acc(f"Retry failed for {item}. 
Leaving file in place.") if combined_items: latest_item = combined_items[-1] return latest_item @@ -665,14 +686,14 @@ def save(self, step=None): print_acc(f"Saved checkpoint to {file_path}") - # save optimizer + # save optimizer (always unwrap so state matches the optimizer we load into before accelerator.prepare()) if self.optimizer is not None: try: filename = f'optimizer.pt' file_path = os.path.join(self.save_root, filename) try: state_dict = unwrap_model(self.optimizer).state_dict() - except Exception as e: + except Exception: state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) print_acc(f"Saved optimizer to {file_path}") @@ -854,12 +875,10 @@ def load_training_state_from_metadata(self, path): def load_weights(self, path): if self.network is not None: - extra_weights = self.network.load_weights(path) + self.network.load_weights(path) self.load_training_state_from_metadata(path) - return extra_weights else: print_acc("load_weights not implemented for non-network models") - return None def apply_snr(self, seperated_loss, timesteps): if self.train_config.learnable_snr_gos: @@ -1236,22 +1255,121 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): # for style, it is best to favor later timesteps orig_timesteps = torch.rand((batch_size,), device=latents.device) + ntt = self.train_config.num_train_timesteps if content_or_style == 'content': - timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps + timestep_indices = (1 - orig_timesteps) ** self.train_config.timestep_bias_exponent * self.train_config.num_train_timesteps elif content_or_style == 'style': - timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps + timestep_indices = orig_timesteps ** self.train_config.timestep_bias_exponent * self.train_config.num_train_timesteps timestep_indices = value_map( timestep_indices, 0, - self.train_config.num_train_timesteps - 1, - min_noise_steps, - max_noise_steps + ntt, + ntt 
- max_noise_steps, + ntt - min_noise_steps ) + timestep_indices = timestep_indices.long().clamp( - min_noise_steps, - max_noise_steps + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices.sort() + + elif content_or_style == 'gaussian': + # Gaussian (normal) distribution with configurable mean and std + # gaussian_mean: controls the center of distribution (default 0.5) + # - Lower values (0.0-0.5): favor earlier timesteps (more noise) + # - Higher values (0.5-1.0): favor later timesteps (less noise) + # gaussian_std: controls the spread of distribution (default 0.2) + # - Affects the shape: smaller = narrower, larger = wider + + # Curriculum learning: linearly interpolate gaussian_std if target is set + if self.train_config.gaussian_std_target is not None: + progress = self.step_num / self.train_config.steps + current_std = self.train_config.gaussian_std + progress * (self.train_config.gaussian_std_target - self.train_config.gaussian_std) + else: + current_std = self.train_config.gaussian_std + + def truncated_normal_samples(batch_size, mu, sigma, device, low=0.0, high=1.0): + # Convert mu and sigma to tensors to ensure all operations are handled by PyTorch + # and to avoid TypeErrors with torch.special functions + mu = torch.as_tensor(mu, device=device, dtype=torch.float32) + sigma = torch.as_tensor(sigma, device=device, dtype=torch.float32) + + # Standardize the boundaries (calculate z-scores) + alpha = (low - mu) / sigma + beta = (high - mu) / sigma + + # Calculate Cumulative Distribution Function (CDF) values for the boundaries + cdf_low = torch.special.ndtr(alpha) + cdf_high = torch.special.ndtr(beta) + + # Generate uniform random samples in the [0, 1] range + u = torch.rand((batch_size,), device=device) + + # Linearly interpolate between the boundary CDF values + v = cdf_low + u * (cdf_high - cdf_low) + + # Clamp values to avoid numerical instability (infinities) when calling ndtri + v = v.clamp(1e-7, 1 - 1e-7) + + # Apply the Inverse CDF 
(Quantile function) to transform samples to the truncated normal distribution + samples = mu + sigma * torch.special.ndtri(v) + + return samples + + gaussian_samples = truncated_normal_samples( + batch_size, + self.train_config.gaussian_mean, + current_std, + latents.device + ) + + # Scale to num_train_timesteps + timestep_indices = gaussian_samples * self.train_config.num_train_timesteps + + ntt = self.train_config.num_train_timesteps + + # Map to min/max_noise_steps range (same as content/style) + timestep_indices = value_map( + timestep_indices, + 0, + ntt, + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices = timestep_indices.long().clamp( + ntt - max_noise_steps, + ntt - min_noise_steps + ) + + timestep_indices.sort() + + elif content_or_style == 'fixed_cycle': + # Deterministic cycle over fixed timestep values (for Turbo LoRA reproducibility) + timestep_list = self.train_config.fixed_cycle_timesteps + if not timestep_list: + raise ValueError("content_or_style is 'fixed_cycle' but fixed_cycle_timesteps is empty") + resolved = getattr(self, '_fixed_cycle_resolved_timesteps', None) + if resolved is None: + list_copy = list(timestep_list) + if self.train_config.fixed_cycle_seed is not None: + random.Random(self.train_config.fixed_cycle_seed).shuffle(list_copy) + st = self.sd.noise_scheduler.timesteps + resolved = [] + for v in list_copy: + v_t = torch.tensor(v, device=st.device, dtype=st.dtype) + idx = (torch.abs(st - v_t)).argmin().item() + resolved.append(st[idx].item()) + self._fixed_cycle_resolved_timesteps = resolved + idx_cycle = self.step_num % len(resolved) + t_val = resolved[idx_cycle] + st = self.sd.noise_scheduler.timesteps + timesteps = torch.full( + (batch_size,), t_val, device=latents.device, dtype=st.dtype ) elif content_or_style == 'balanced': @@ -1275,8 +1393,71 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): else: raise ValueError(f"Unknown content_or_style {content_or_style}") with 
self.timer('convert_timestep_indices_to_timesteps'): - # convert the timestep_indices to a timestep - timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + # convert the timestep_indices to a timestep (fixed_cycle already set timesteps above) + if content_or_style != 'fixed_cycle': + timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + + # Debug logging for timestep distribution + if self.train_config.timestep_debug_log > 0: + # Always collect data (fixed_cycle has no timestep_indices, use cycle index) + if content_or_style == 'fixed_cycle': + self._collected_indices.append(self.step_num % len(self._fixed_cycle_resolved_timesteps)) + self._collected_timesteps.extend(timesteps.cpu().tolist()) + else: + self._collected_indices.extend(timestep_indices.cpu().tolist()) + self._collected_timesteps.extend(timesteps.cpu().tolist()) + + # Log when we have enough samples + if len(self._collected_indices) >= self.train_config.timestep_debug_log: + scheduler_timesteps = self.sd.noise_scheduler.timesteps.cpu().tolist() + + print_acc(f"\n{'='*70}") + print_acc(f"TIMESTEP DISTRIBUTION DEBUG") + print_acc(f"{'='*70}") + + print_acc(f"Total scheduler timesteps length: {len(scheduler_timesteps)}") + # print_acc(f"\nScheduler timesteps array (first 10): {scheduler_timesteps[:10]}") + # print_acc(f"Scheduler timesteps array (last 10): {scheduler_timesteps[-10:]}") + + print_acc(f"\nFirst 10 timestep_indices (generated indices):") + print_acc(f"{self._collected_indices[:10]}") + print_acc(f"\nFirst 10 timesteps (actual values after indexing):") + print_acc(f"{self._collected_timesteps[:10]}") + + print_acc(f"Config:") + print_acc(f" content_or_style: {content_or_style}") + print_acc(f" noise_scheduler: {self.train_config.noise_scheduler}") + print_acc(f" timestep_type: {self.train_config.timestep_type}") + print_acc(f" num_train_timesteps: {self.train_config.num_train_timesteps}") + print_acc(f" min_denoising_steps: {min_noise_steps}") + 
print_acc(f" max_denoising_steps: {max_noise_steps}") + print_acc(f" gaussian_mean: {self.train_config.gaussian_mean}") + print_acc(f" gaussian_std: {self.train_config.gaussian_std}") + print_acc(f" gaussian_std_target: {self.train_config.gaussian_std_target}") + + # Statistics + num_samples = self.train_config.timestep_debug_log + indices_min = min(self._collected_indices[:num_samples]) + indices_max = max(self._collected_indices[:num_samples]) + indices_mean = sum(self._collected_indices[:num_samples]) / num_samples + + timesteps_min = min(self._collected_timesteps[:num_samples]) + timesteps_max = max(self._collected_timesteps[:num_samples]) + timesteps_mean = sum(self._collected_timesteps[:num_samples]) / num_samples + + print_acc(f"\nStatistics ({num_samples} samples):") + print_acc(f" Indices: max={indices_max}, mean={indices_mean:.1f}, min={indices_min}") + print_acc(f" Timesteps: max={timesteps_max:.1f}, mean={timesteps_mean:.1f}, min={timesteps_min:.1f}") + if self.train_config.gaussian_std_target is not None: + progress = self.step_num / self.train_config.steps + current_std = self.train_config.gaussian_std + progress * (self.train_config.gaussian_std_target - self.train_config.gaussian_std) + percent = progress * 100.0 + print_acc(f" current_std: {current_std:.3f} (progress: {percent:.1f}%)") + print_acc(f" Step: {self.step_num} ({self.step_num * 100 / self.train_config.steps:.1f}%)") + print_acc(f"{'='*70}\n") + + self._collected_indices = [] + self._collected_timesteps = [] with self.timer('prepare_noise'): # get noise @@ -1758,6 +1939,8 @@ def run(self): is_ssd=self.model_config.is_ssd, is_vega=self.model_config.is_vega, dropout=self.network_config.dropout, + rank_dropout=self.network_config.rank_dropout, + module_dropout=self.network_config.module_dropout, use_text_encoder_1=self.model_config.use_text_encoder_1, use_text_encoder_2=self.model_config.use_text_encoder_2, use_bias=is_lorm, @@ -1806,6 +1989,65 @@ def run(self): 
self.network.prepare_grad_etc(text_encoder, unet) flush() + # Create sampling network if sampling_transformer is specified and LoRA is used + if hasattr(self.sd, '_sampling_transformer') and self.sd._sampling_transformer is not None: + with memory_debug(print_acc, "Creating sampling network"): + print_acc("Creating sampling network for _sampling_transformer") + sampling_network = NetworkClass( + text_encoder=text_encoder, + unet=self.sd._sampling_transformer, + lora_dim=self.network_config.linear, + multiplier=1.0, + alpha=self.network_config.linear_alpha, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + conv_lora_dim=self.network_config.conv, + conv_alpha=self.network_config.conv_alpha, + is_sdxl=self.model_config.is_xl or self.model_config.is_ssd, + is_v2=self.model_config.is_v2, + is_v3=self.model_config.is_v3, + is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, + is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, + is_ssd=self.model_config.is_ssd, + is_vega=self.model_config.is_vega, + dropout=self.network_config.dropout, + rank_dropout=self.network_config.rank_dropout, + module_dropout=self.network_config.module_dropout, + use_text_encoder_1=self.model_config.use_text_encoder_1, + use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + network_type=self.network_config.type, + transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, + base_model=self.sd, + **network_kwargs + ) + + sampling_network.force_to(self.device_torch, dtype=torch.float32) + sampling_network._update_torch_multiplier() + + sampling_network.apply_to( + text_encoder, + self.sd._sampling_transformer, + self.train_config.train_text_encoder, + self.train_config.train_unet + ) + + # Set can_merge_in same as main network (False if quantized/layer_offloading) + if 
self.model_config.quantize or self.model_config.layer_offloading: + sampling_network.can_merge_in = False + + # Store sampling network on model for use during generation + self.sd._sampling_network = sampling_network + # One LoRA for both: share parameters with training network (no copy/sync) + if hasattr(sampling_network, "share_parameters_with"): + sampling_network.share_parameters_with(self.network) + flush() + # LyCORIS doesnt have default_lr config = { 'text_encoder_lr': self.train_config.lr, @@ -1831,13 +2073,14 @@ def run(self): lora_name = f"{lora_name}_LoRA" latest_save_path = self.get_latest_save_path(lora_name) - extra_weights = None if latest_save_path is not None: - print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") - print_acc(f"Loading from {latest_save_path}") - extra_weights = self.load_weights(latest_save_path) - self.network.multiplier = 1.0 - + with memory_debug(print_acc, "pretrained_lora_load"): + print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") + print_acc(f"Loading from {latest_save_path}") + self.load_weights(latest_save_path) + self.network.multiplier = 1.0 + flush() + if self.network_config.layer_offloading: MemoryManager.attach( self.network, @@ -1983,7 +2226,27 @@ def run(self): try: print_acc(f"Loading optimizer state from {optimizer_state_file_path}") optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) - optimizer.load_state_dict(optimizer_state_dict) + + # PyTorch maps optimizer state by param order; param count must match + # or state will be applied to wrong parameters + current_param_count = sum( + len(g["params"]) for g in optimizer.param_groups + ) + saved_param_count = sum( + len(g["params"]) for g in optimizer_state_dict.get("param_groups", []) + ) + + if current_param_count != saved_param_count: + print_acc( + f"WARNING: Optimizer state NOT loaded: param count mismatch. 
" + f"Current optimizer has {current_param_count} params, " + f"but saved state has {saved_param_count} params. " + "Training will continue with fresh optimizer state." + ) + else: + optimizer.load_state_dict(optimizer_state_dict) + print_acc("Optimizer state restored successfully.") + del optimizer_state_dict flush() except Exception as e: @@ -2036,7 +2299,8 @@ def run(self): print_acc("Skipping first sample due to config setting") elif self.step_num <= 1 or self.train_config.force_first_sample: print_acc("Generating baseline samples before training") - self.sample(self.step_num) + with memory_debug(print_acc, "Generating baseline samples before training"): + self.sample(self.step_num) if self.accelerator.is_local_main_process: self.progress_bar = ToolkitProgressBar( @@ -2064,6 +2328,11 @@ def run(self): dataloader_reg = None dataloader_iterator_reg = None + # Sync epoch_num to datasets so resume from checkpoint keeps correct shuffle-for-cached-embeddings behavior + if dataloader is not None: + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + # zero any gradients optimizer.zero_grad() @@ -2149,6 +2418,11 @@ def run(self): dataloader_iterator = iter(dataloader) trigger_dataloader_setup_epoch(dataloader) self.epoch_num += 1 + for dataset in get_dataloader_datasets(dataloader): + dataset.set_epoch_num(self.epoch_num) + clear_embeddings_cache = not getattr(dataset.dataset_config, 'cache_text_embeddings', False) + if clear_embeddings_cache: + dataset.clear_cached_embeddings_memory() if self.train_config.gradient_accumulation_steps == -1: # if we are accumulating for an entire epoch, trigger a step self.is_grad_accumulation_step = False @@ -2226,6 +2500,7 @@ def run(self): # torch.cuda.empty_cache() # if optimizer has get_lrs method, then use it learning_rate = 0.0 + update_rms = 0.0 # Average weight update RMS (for monitoring optimizer step magnitude) if not did_oom and loss_dict is not None: if hasattr(optimizer, 
'get_avg_learning_rate'): learning_rate = optimizer.get_avg_learning_rate() @@ -2239,8 +2514,14 @@ def run(self): ) else: learning_rate = optimizer.param_groups[0]['lr'] + + # Get average weight update RMS if optimizer supports it (e.g., Adafactor) + if hasattr(optimizer, 'get_avg_update_rms'): + update_rms = optimizer.get_avg_update_rms() prog_bar_string = f"lr: {learning_rate:.1e}" + if update_rms > 0: + prog_bar_string += f" upd: {update_rms:.2e}" for key, value in loss_dict.items(): prog_bar_string += f" {key}: {value:.3e}" @@ -2263,6 +2544,11 @@ def run(self): if self.progress_bar is not None: self.progress_bar.pause() print_acc(f"\nSaving at step {self.step_num}") + # free memory before save to reduce OOM risk + optimizer.zero_grad(set_to_none=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + flush() self.save(self.step_num) self.ensure_params_requires_grad() # clear any grads @@ -2300,6 +2586,9 @@ def run(self): for key, value in loss_dict.items(): self.writer.add_scalar(f"{key}", value, self.step_num) self.writer.add_scalar(f"lr", learning_rate, self.step_num) + # Log weight update RMS if available (shows optimizer step magnitude) + if update_rms > 0: + self.writer.add_scalar("train/update_rms", update_rms, self.step_num) if self.progress_bar is not None: self.progress_bar.unpause() @@ -2308,6 +2597,17 @@ def run(self): self.logger.log({ 'learning_rate': learning_rate, }) + # Log differential guidance norm if available + if hasattr(self, 'diff_guidance_norm') and self.diff_guidance_norm is not None: + self.logger.log({ + 'diff_guidance_norm': self.diff_guidance_norm, + }) + self.diff_guidance_norm = None + # Log weight update RMS if available (Adafactor optimizer statistic) + if update_rms > 0: + self.logger.log({ + 'train/update_rms': update_rms, + }) if loss_dict is not None: for key, value in loss_dict.items(): self.logger.log({ @@ -2319,6 +2619,17 @@ def run(self): self.logger.log({ 'learning_rate': learning_rate, }) + # Log 
differential guidance norm if available + if hasattr(self, 'diff_guidance_norm') and self.diff_guidance_norm is not None: + self.logger.log({ + 'diff_guidance_norm': self.diff_guidance_norm, + }) + self.diff_guidance_norm = None + # Log weight update RMS if available (Adafactor optimizer statistic) + if update_rms > 0: + self.logger.log({ + 'train/update_rms': update_rms, + }) for key, value in loss_dict.items(): self.logger.log({ f'loss/{key}': value, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a48cbe09b..7962c0788 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -6,6 +6,7 @@ import torch import torchaudio +from toolkit.paths import normalize_path from toolkit.prompt_utils import PromptEmbeds ImgExt = Literal['jpg', 'png', 'webp'] @@ -36,6 +37,7 @@ def __init__(self, **kwargs): self.verbose: bool = kwargs.get('verbose', False) self.use_wandb: bool = kwargs.get('use_wandb', False) self.use_ui_logger: bool = kwargs.get('use_ui_logger', False) + self.debug: bool = kwargs.get('debug', False) self.project_name: str = kwargs.get('project_name', 'ai-toolkit') self.run_name: str = kwargs.get('run_name', None) @@ -182,6 +184,8 @@ def __init__(self, **kwargs): self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) self.dropout: Union[float, None] = kwargs.get('dropout', None) + self.rank_dropout: Union[float, None] = kwargs.get('rank_dropout', None) + self.module_dropout: Union[float, None] = kwargs.get('module_dropout', None) self.network_kwargs: dict = kwargs.get('network_kwargs', {}) self.lorm_config: Union[LoRMConfig, None] = None @@ -217,7 +221,7 @@ def __init__(self, **kwargs): self.layer_offloading = kwargs.get('layer_offloading', False) # start from a pretrained lora - self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None) + self.pretrained_lora_path = normalize_path(kwargs.get('pretrained_lora_path', None)) AdapterTypes = 
Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] @@ -344,7 +348,7 @@ def __init__(self, **kwargs): self.num_tokens: str = kwargs.get('num_tokens', 4) -ContentOrStyleType = Literal['balanced', 'style', 'content'] +ContentOrStyleType = Literal['balanced', 'style', 'content', 'gaussian', 'fixed_cycle'] LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] @@ -353,6 +357,18 @@ def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + self.gaussian_mean: float = kwargs.get('gaussian_mean', 0.5) + self.gaussian_std: float = kwargs.get('gaussian_std', 0.2) + self.gaussian_std_target: float = kwargs.get('gaussian_std_target', None) + self.timestep_bias_exponent: float = kwargs.get('timestep_bias_exponent', 3.0) + self.timestep_debug_log: int = kwargs.get('timestep_debug_log', 0) + # fixed_cycle: deterministic cycle over fixed timestep values (for Turbo LoRA reproducibility) + _default_fixed_cycle = [999, 875, 750, 625, 500, 375, 250, 125] + _fc = kwargs.get('fixed_cycle_timesteps', _default_fixed_cycle) + self.fixed_cycle_timesteps: Optional[List[float]] = _fc if (_fc is not None and len(_fc) > 0) else _default_fixed_cycle + self.fixed_cycle_seed: Optional[int] = kwargs.get('fixed_cycle_seed', None) + self.fixed_cycle_weight_peak_timesteps: Optional[List[float]] = kwargs.get('fixed_cycle_weight_peak_timesteps', [500, 375]) + self.fixed_cycle_weight_sigma: float = kwargs.get('fixed_cycle_weight_sigma', 372.8) self.steps: int = kwargs.get('steps', 1000) self.lr = kwargs.get('lr', 1e-6) self.unet_lr = kwargs.get('unet_lr', self.lr) @@ -464,6 +480,7 @@ def __init__(self, **kwargs): # blank prompt preservation will preserve the model's knowledge of a blank prompt self.blank_prompt_preservation = 
kwargs.get('blank_prompt_preservation', False) self.blank_prompt_preservation_multiplier = kwargs.get('blank_prompt_preservation_multiplier', 1.0) + self.blank_prompt_probability = kwargs.get('blank_prompt_probability', 1.0) # legacy if match_adapter_assist and self.match_adapter_chance == 0.0: @@ -667,6 +684,13 @@ def __init__(self, **kwargs): # 20 different model variants self.extras_name_or_path = kwargs.get("extras_name_or_path", self.name_or_path) + # for models that support it (e.g., zimage), a separate model path for sampling/inference + # training uses name_or_path, sampling uses sampling_name_or_path if set + self.sampling_name_or_path: Optional[str] = kwargs.get("sampling_name_or_path", None) + + # enable debug logging for safetensors load (path, size, duration) to diagnose mmap/load differences + self.debug_zimage_load: bool = kwargs.get("debug_zimage_load", False) + # path to an accuracy recovery adapter, either local or remote self.accuracy_recovery_adapter = kwargs.get("accuracy_recovery_adapter", None) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 51605c982..cfb084aa0 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -598,7 +598,10 @@ def __len__(self): def _get_single_item(self, index) -> 'FileItemDTO': file_item: 'FileItemDTO' = copy.deepcopy(self.file_list[index]) + file_item._current_epoch_num = getattr(self, '_epoch_num', 0) file_item.load_and_process_image(self.transform) + if file_item.is_text_embedding_cached and file_item.prompt_embeds is not None: + self.file_list[index].prompt_embeds = file_item.prompt_embeds file_item.load_caption(self.caption_dict) return file_item diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 20140eb8b..151f60d2a 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -415,8 +415,9 @@ def get_caption( new_token_list.append(token) token_list = new_token_list - if self.dataset_config.shuffle_tokens: - random.shuffle(token_list) 
+ # Redundant code. Shuffling is performed in the code below. + # if self.dataset_config.shuffle_tokens: + # random.shuffle(token_list) # join back together caption = ', '.join(token_list) @@ -436,14 +437,17 @@ def get_caption( # trigger = self.dataset_config.random_triggers[int(random.random() * (len(self.dataset_config.random_triggers)))] # caption = caption + ', ' + trigger - if self.dataset_config.shuffle_tokens: - # shuffle again + if self.dataset_config.shuffle_tokens and not self.dataset_config.cache_text_embeddings: + # shuffle, keep first segment (until first comma) in place token_list = caption.split(',') # trim whitespace token_list = [x.strip() for x in token_list] # remove empty strings token_list = [x for x in token_list if x] - random.shuffle(token_list) + if len(token_list) > 1: + rest = token_list[1:] + random.shuffle(rest) + token_list = [token_list[0]] + rest caption = ', '.join(token_list) if caption == '': pass @@ -578,6 +582,8 @@ def load_and_process_video( img = img.transpose(Image.FLIP_TOP_BOTTOM) # Apply bucketing + if self.scale_to_width <= 0 or self.scale_to_height <= 0: + raise ValueError(f"Invalid scale dimensions for video {self.path}: scale_to_width={self.scale_to_width}, scale_to_height={self.scale_to_height}") img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) img = img.crop(( self.crop_x, @@ -780,6 +786,8 @@ def load_and_process_image( if self.dataset_config.buckets: # scale and crop based on file item + if self.scale_to_width <= 0 or self.scale_to_height <= 0: + raise ValueError(f"Invalid scale dimensions for image {self.path}: scale_to_width={self.scale_to_width}, scale_to_height={self.scale_to_height}") img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height: @@ -1961,6 +1969,9 @@ def load_prompt_embedding(self, device=None): 
if self.prompt_embeds is None: # load it from disk self.prompt_embeds = PromptEmbeds.load(self.get_text_embedding_path()) + if getattr(self, '_current_epoch_num', 0) > 0 and getattr(self.dataset_config, 'shuffle_tokens', False): + self.prompt_embeds.shuffle_sequence() + print_acc(f"Cached text embedding tokens shuffled (epoch {getattr(self, '_current_epoch_num', 0)})") class TextEmbeddingCachingMixin: def __init__(self: 'AiToolkitDataset', **kwargs): @@ -1968,6 +1979,19 @@ def __init__(self: 'AiToolkitDataset', **kwargs): if hasattr(super(), '__init__'): super().__init__(**kwargs) self.is_caching_text_embeddings = self.dataset_config.cache_text_embeddings + self._epoch_num: int = 0 + + def set_epoch_num(self: 'AiToolkitDataset', epoch_num: int) -> None: + self._epoch_num = epoch_num + if epoch_num > 0 and self.is_caching_text_embeddings and self.dataset_config.shuffle_tokens: + for file_item in self.file_list: + if getattr(file_item, 'prompt_embeds', None) is not None: + file_item.prompt_embeds.shuffle_sequence() + print_acc(f"\nCached text embedding tokens shuffled (epoch {epoch_num})\n") + + def clear_cached_embeddings_memory(self: 'AiToolkitDataset') -> None: + for file_item in self.file_list: + file_item.cleanup_text_embedding() def cache_text_embeddings(self: 'AiToolkitDataset'): with accelerator.main_process_first(): diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 5cb19229a..2e799f7f1 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -14,6 +14,7 @@ from .config_modules import NetworkConfig from .lorm import count_parameters from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin +from toolkit.util.debug import memory_debug from toolkit.kohya_lora import LoRANetwork from toolkit.models.DoRA import DoRAModule @@ -279,246 +280,250 @@ def __init__( if self.network_type.lower() != "lokr" or not self.use_old_lokr_format: self.peft_format = True - if self.peft_format: - # no alpha for peft 
- self.alpha = self.lora_dim - alpha = self.alpha - self.conv_alpha = self.conv_lora_dim - conv_alpha = self.conv_alpha + # Allow alpha for peft + # Todo: Folding Alpha on save + # + # if self.peft_format: + # # no alpha for peft + # self.alpha = self.lora_dim + # alpha = self.alpha + # self.conv_alpha = self.conv_lora_dim + # conv_alpha = self.conv_alpha self.full_train_in_out = full_train_in_out - if modules_dim is not None: - print(f"create LoRA network from weights") - elif block_dims is not None: - print(f"create LoRA network from block_dims") - print( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - print(f"block_dims: {block_dims}") - print(f"block_alphas: {block_alphas}") - if conv_block_dims is not None: - print(f"conv_block_dims: {conv_block_dims}") - print(f"conv_block_alphas: {conv_block_alphas}") - else: - print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - print( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") - if self.conv_lora_dim is not None: + with memory_debug(print, "create LoRA network"): + if modules_dim is not None: + print(f"create LoRA network from weights") + elif block_dims is not None: + print(f"create LoRA network from block_dims") + print( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + print(f"block_dims: {block_dims}") + print(f"block_alphas: {block_alphas}") + if conv_block_dims is not None: + print(f"conv_block_dims: {conv_block_dims}") + print(f"conv_block_alphas: {conv_block_alphas}") + else: + print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") print( - f"apply LoRA to Conv2d with kernel size (3,3). 
dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + if self.conv_lora_dim is not None: + print( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") - # create module instances - def create_modules( + # create module instances + def create_modules( is_unet: bool, text_encoder_idx: Optional[int], # None, 1, 2 root_module: torch.nn.Module, target_replace_modules: List[torch.nn.Module], - ) -> List[LoRAModule]: - unet_prefix = self.LORA_PREFIX_UNET - if self.peft_format: - unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: - unet_prefix = f"lora_transformer" + ) -> List[LoRAModule]: + unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: - unet_prefix = "transformer" - - prefix = ( - unet_prefix - if is_unet - else ( - self.LORA_PREFIX_TEXT_ENCODER - if text_encoder_idx is None - else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + unet_prefix = self.PEFT_PREFIX_UNET + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: + unet_prefix = f"lora_transformer" + if self.peft_format: + unet_prefix = "transformer" + + prefix = ( + unet_prefix + if is_unet + else ( + self.LORA_PREFIX_TEXT_ENCODER + if text_encoder_idx is None + else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2) + ) ) - ) - loras = [] - skipped = [] - attached_modules = [] - lora_shape_dict = {} - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ in LINEAR_MODULES - is_conv2d = child_module.__class__.__name__ in CONV_MODULES - is_conv2d_1x1 = is_conv2d and 
child_module.kernel_size == (1, 1) - - - lora_name = [prefix, name, child_name] - # filter out blank - lora_name = [x for x in lora_name if x and x != ""] - lora_name = ".".join(lora_name) - # if it doesnt have a name, it wil have two dots - lora_name.replace("..", ".") - clean_name = lora_name - if self.peft_format: - # we replace this on saving - lora_name = lora_name.replace(".", "$$") - else: - lora_name = lora_name.replace(".", "_") - - skip = False - if any([word in clean_name for word in self.ignore_if_contains]): - skip = True - - # see if it is over threshold - if count_parameters(child_module) < parameter_threshold: - skip = True - - if self.transformer_only and is_unet: - transformer_block_names = None - if base_model is not None: - transformer_block_names = base_model.get_transformer_block_names() - - if transformer_block_names is not None: - if not any([name in lora_name for name in transformer_block_names]): - skip = True + loras = [] + skipped = [] + attached_modules = [] + lora_shape_dict = {} + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + + lora_name = [prefix, name, child_name] + # filter out blank + lora_name = [x for x in lora_name if x and x != ""] + lora_name = ".".join(lora_name) + # if it doesnt have a name, it wil have two dots + lora_name.replace("..", ".") + clean_name = lora_name + if self.peft_format: + # we replace this on saving + lora_name = lora_name.replace(".", "$$") else: - if self.is_pixart: - if "transformer_blocks" not in lora_name: - skip = True - if self.is_flux: - if "transformer_blocks" not in lora_name: - skip = True - if self.is_lumina2: - if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name 
and "context_refiner$$" not in lora_name: - skip = True - if self.is_v3: - if "transformer_blocks" not in lora_name: - skip = True - - # handle custom models - if hasattr(root_module, 'transformer_blocks'): - if "transformer_blocks" not in lora_name: - skip = True - - if hasattr(root_module, 'blocks'): - if "blocks" not in lora_name: - skip = True - - if hasattr(root_module, 'single_blocks'): - if "single_blocks" not in lora_name and "double_blocks" not in lora_name: - skip = True - - if (is_linear or is_conv2d) and not skip: + lora_name = lora_name.replace(".", "_") - if self.only_if_contains is not None: - if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): - continue + skip = False + if any([word in clean_name for word in self.ignore_if_contains]): + skip = True - dim = None - alpha = None - - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = self.lora_dim - alpha = self.alpha - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha - - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or ( - self.conv_lora_dim is not None or conv_block_dims is not None): - skipped.append(lora_name) - continue + # see if it is over threshold + if count_parameters(child_module) < parameter_threshold: + skip = True - module_kwargs = {} - - if self.network_type.lower() == "lokr": - module_kwargs["factor"] = self.network_config.lokr_factor - - if self.is_ara: - module_kwargs["is_ara"] = True - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - network=self, - parent=module, - use_bias=use_bias, - **module_kwargs - ) - loras.append(lora) - if self.network_type.lower() == 
"lokr": - try: - lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)] - except: - pass - else: - if self.full_rank: - lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + if self.transformer_only and is_unet: + transformer_block_names = None + if base_model is not None: + transformer_block_names = base_model.get_transformer_block_names() + + if transformer_block_names is not None: + if not any([name in lora_name for name in transformer_block_names]): + skip = True else: - lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] - return loras, skipped - - text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - - # create LoRA for text encoder - # 毎回すべてのモジュールを作るのは無駄なので要検討 - self.text_encoder_loras = [] - skipped_te = [] - if train_text_encoder: - for i, text_encoder in enumerate(text_encoders): - if not use_text_encoder_1 and i == 0: - continue - if not use_text_encoder_2 and i == 1: - continue - if len(text_encoders) > 1: - index = i + 1 - print(f"create LoRA for Text Encoder {index}:") - else: - index = None - print(f"create LoRA for Text Encoder:") - - replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - - if self.is_pixart: - replace_modules = ["T5EncoderModel"] - - text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) - self.text_encoder_loras.extend(text_encoder_loras) - skipped_te += skipped - print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") - - # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = target_lin_modules - if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: - target_modules += target_conv_modules - - if is_v3: - target_modules = ["SD3Transformer2DModel"] - - if is_pixart: - target_modules = ["PixArtTransformer2DModel"] - - if is_auraflow: - target_modules = 
["AuraFlowTransformer2DModel"] - - if is_flux: - target_modules = ["FluxTransformer2DModel"] - - if is_lumina2: - target_modules = ["Lumina2Transformer2DModel"] + if self.is_pixart: + if "transformer_blocks" not in lora_name: + skip = True + if self.is_flux: + if "transformer_blocks" not in lora_name: + skip = True + if self.is_lumina2: + if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name: + skip = True + if self.is_v3: + if "transformer_blocks" not in lora_name: + skip = True + + # handle custom models + if hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'blocks'): + if "blocks" not in lora_name: + skip = True + + if hasattr(root_module, 'single_blocks'): + if "single_blocks" not in lora_name and "double_blocks" not in lora_name: + skip = True + + if (is_linear or is_conv2d) and not skip: + + if self.only_if_contains is not None: + if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]): + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + else: + # 通常、すべて対象とする + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or ( + self.conv_lora_dim is not None or conv_block_dims is not None): + skipped.append(lora_name) + continue + + module_kwargs = {} + + if self.network_type.lower() == "lokr": + module_kwargs["factor"] = self.network_config.lokr_factor + + if self.is_ara: + module_kwargs["is_ara"] = True + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + 
rank_dropout=rank_dropout, + module_dropout=module_dropout, + network=self, + parent=module, + use_bias=use_bias, + **module_kwargs + ) + loras.append(lora) + if self.network_type.lower() == "lokr": + try: + lora_shape_dict[lora_name] = [list(lora.lokr_w1.weight.shape), list(lora.lokr_w2.weight.shape)] + except: + pass + else: + if self.full_rank: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)] + else: + lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)] + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras = [] + skipped_te = [] + if train_text_encoder: + for i, text_encoder in enumerate(text_encoders): + if not use_text_encoder_1 and i == 0: + continue + if not use_text_encoder_2 and i == 1: + continue + if len(text_encoders) > 1: + index = i + 1 + print(f"create LoRA for Text Encoder {index}:") + else: + index = None + print(f"create LoRA for Text Encoder:") + + replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + + if self.is_pixart: + replace_modules = ["T5EncoderModel"] + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + target_modules = target_lin_modules + if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: + target_modules += target_conv_modules + + if is_v3: + target_modules = ["SD3Transformer2DModel"] + + if is_pixart: + target_modules = ["PixArtTransformer2DModel"] + + if is_auraflow: + target_modules = ["AuraFlowTransformer2DModel"] + + if is_flux: + target_modules = ["FluxTransformer2DModel"] + + if is_lumina2: + 
target_modules = ["Lumina2Transformer2DModel"] - if train_unet: - self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - else: - self.unet_loras = [] - skipped_un = [] - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + if train_unet: + self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) + else: + self.unet_loras = [] + skipped_un = [] + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: @@ -577,6 +582,31 @@ def create_modules( unet.conv_in = self.unet_conv_in unet.conv_out = self.unet_conv_out + def share_parameters_with(self, other: "LoRASpecialNetwork") -> None: + """ + Share all trainable parameters with another network of the same structure. + Used so sampling_network uses the same LoRA weights as the training network + (one LoRA for both training and sampling, no copy/sync). + """ + assert len(self.unet_loras) == len(other.unet_loras), "unet_loras length mismatch" + assert len(self.text_encoder_loras) == len(other.text_encoder_loras), "text_encoder_loras length mismatch" + + def _share_lora_pair(my_lora: torch.nn.Module, other_lora: torch.nn.Module) -> None: + assert getattr(my_lora, "lora_name", None) == getattr(other_lora, "lora_name", None), ( + f"lora name mismatch: {getattr(my_lora, 'lora_name', None)} vs {getattr(other_lora, 'lora_name', None)}" + ) + for name, param in other_lora.named_parameters(): + parts = name.split(".") + obj = my_lora + for p in parts[:-1]: + obj = getattr(obj, p) + setattr(obj, parts[-1], param) + + for my_lora, other_lora in zip(self.unet_loras, other.unet_loras): + _share_lora_pair(my_lora, other_lora) + for my_lora, other_lora in zip(self.text_encoder_loras, other.text_encoder_loras): + _share_lora_pair(my_lora, other_lora) + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # call Lora prepare_optimizer_params all_params = 
super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index ae117c857..906996f7b 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -41,6 +41,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc +from toolkit.util.debug import is_debug_enabled if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -410,6 +411,7 @@ def generate_images( rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + pipeline_created = (pipeline is None) if pipeline is None: pipeline = self.get_generation_pipeline() try: @@ -423,254 +425,277 @@ def generate_images( # pipeline.to(self.device_torch) - with network: - with torch.no_grad(): - if network is not None: - assert network.is_active + try: + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active - for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): - gen_config = image_configs[i] + # Load sampling transformer onto device if it exists + if self._sampling_transformer is not None: + self.model.to("cpu") + self._sampling_transformer.to(self.device_torch, dtype=self.torch_dtype) + else: + self.model.to(self.device_torch, dtype=self.torch_dtype) - extra = {} - validation_image = None - if self.adapter is not None and gen_config.adapter_image_path is not None: - validation_image = Image.open(gen_config.adapter_image_path) - if ".inpaint." not in gen_config.adapter_image_path: - validation_image = validation_image.convert("RGB") - else: - # make sure it has an alpha - if validation_image.mode != "RGBA": - raise ValueError("Inpainting images must have an alpha channel") - if isinstance(self.adapter, T2IAdapter): - # not sure why this is double?? 
- validation_image = validation_image.resize( - (gen_config.width * 2, gen_config.height * 2)) - extra['image'] = validation_image - extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, ControlNetModel): - validation_image = validation_image.resize( - (gen_config.width, gen_config.height)) - extra['image'] = validation_image - extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale - if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: - validation_image = validation_image.resize((gen_config.width, gen_config.height)) - extra['control_image'] = validation_image - extra['control_image_idx'] = gen_config.ctrl_idx - if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - if isinstance(self.adapter, CustomAdapter): - # todo allow loading multiple - transform = transforms.Compose([ - transforms.ToTensor(), - ]) - validation_image = transform(validation_image) - self.adapter.num_images = 1 - if isinstance(self.adapter, ReferenceAdapter): - # need -1 to 1 - validation_image = transforms.ToTensor()(validation_image) - validation_image = validation_image * 2.0 - 1.0 - validation_image = validation_image.unsqueeze(0) - self.adapter.set_reference_images(validation_image) + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] - if network is not None: - network.multiplier = gen_config.network_multiplier - torch.manual_seed(gen_config.seed) - torch.cuda.manual_seed(gen_config.seed) - - generator = torch.manual_seed(gen_config.seed) - - if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ - and gen_config.adapter_image_path is not None: - # run through the adapter to saturate the embeds - conditional_clip_embeds = 
self.adapter.get_clip_image_embeds_from_tensors( - validation_image) - self.adapter(conditional_clip_embeds) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter): - # handle condition the prompts - gen_config.prompt = self.adapter.condition_prompt( - gen_config.prompt, - is_unconditional=False, - ) - gen_config.prompt_2 = gen_config.prompt - gen_config.negative_prompt = self.adapter.condition_prompt( - gen_config.negative_prompt, - is_unconditional=True, - ) - gen_config.negative_prompt_2 = gen_config.negative_prompt - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: - self.adapter.trigger_pre_te( - tensors_0_1=validation_image, - is_training=False, - has_been_preprocessed=False, - quad_count=4 - ) + extra = {} - if self.sample_prompts_cache is not None: - conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) - unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) - else: - ctrl_img = None - has_control_images = False - if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: - has_control_images = True - # load the control image if out model uses it in text encoding - if has_control_images and self.encode_control_in_text_embeddings: - ctrl_img_list = [] + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path) + if ".inpaint." not in gen_config.adapter_image_path: + validation_image = validation_image.convert("RGB") + else: + # make sure it has an alpha + if validation_image.mode != "RGBA": + raise ValueError("Inpainting images must have an alpha channel") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
+ validation_image = validation_image.resize( + (gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize( + (gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, CustomAdapter) and self.adapter.control_lora is not None: + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['control_image'] = validation_image + extra['control_image_idx'] = gen_config.ctrl_idx + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None 
and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + if self.sample_prompts_cache is not None: + conditional_embeds = self.sample_prompts_cache[i]['conditional'].to(self.device_torch, dtype=self.torch_dtype) + unconditional_embeds = self.sample_prompts_cache[i]['unconditional'].to(self.device_torch, dtype=self.torch_dtype) + else: + ctrl_img = None + has_control_images = False + if gen_config.ctrl_img is not None or gen_config.ctrl_img_1 is not None or gen_config.ctrl_img_2 is not None or gen_config.ctrl_img_3 is not None: + has_control_images = True + # load the control image if out model uses it in text encoding + if has_control_images and self.encode_control_in_text_embeddings: + ctrl_img_list = [] - if gen_config.ctrl_img is not None: - ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img = ( - TF.to_tensor(ctrl_img) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img) + if gen_config.ctrl_img is not None: + ctrl_img = Image.open(gen_config.ctrl_img).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img = ( + TF.to_tensor(ctrl_img) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img) - if gen_config.ctrl_img_1 is not None: - ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_1 = ( 
- TF.to_tensor(ctrl_img_1) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_1) - if gen_config.ctrl_img_2 is not None: - ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_2 = ( - TF.to_tensor(ctrl_img_2) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_2) - if gen_config.ctrl_img_3 is not None: - ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") - # convert to 0 to 1 tensor - ctrl_img_3 = ( - TF.to_tensor(ctrl_img_3) - .unsqueeze(0) - .to(self.device_torch, dtype=self.torch_dtype) - ) - ctrl_img_list.append(ctrl_img_3) + if gen_config.ctrl_img_1 is not None: + ctrl_img_1 = Image.open(gen_config.ctrl_img_1).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_1 = ( + TF.to_tensor(ctrl_img_1) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_1) + if gen_config.ctrl_img_2 is not None: + ctrl_img_2 = Image.open(gen_config.ctrl_img_2).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_2 = ( + TF.to_tensor(ctrl_img_2) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_2) + if gen_config.ctrl_img_3 is not None: + ctrl_img_3 = Image.open(gen_config.ctrl_img_3).convert("RGB") + # convert to 0 to 1 tensor + ctrl_img_3 = ( + TF.to_tensor(ctrl_img_3) + .unsqueeze(0) + .to(self.device_torch, dtype=self.torch_dtype) + ) + ctrl_img_list.append(ctrl_img_3) - if self.has_multiple_control_images: - ctrl_img = ctrl_img_list - else: - ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None - # encode the prompt ourselves so we can do fun stuff with embeddings - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - conditional_embeds = self.encode_prompt( - gen_config.prompt, - gen_config.prompt_2, - force_all=True, - control_images=ctrl_img - ) - - if 
isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = True - unconditional_embeds = self.encode_prompt( - gen_config.negative_prompt, - gen_config.negative_prompt_2, - force_all=True, - control_images=ctrl_img + if self.has_multiple_control_images: + ctrl_img = ctrl_img_list + else: + ctrl_img = ctrl_img_list[0] if len(ctrl_img_list) > 0 else None + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, + gen_config.prompt_2, + force_all=True, + control_images=ctrl_img + ) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, + gen_config.negative_prompt_2, + force_all=True, + control_images=ctrl_img + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, ) - if isinstance(self.adapter, CustomAdapter): - self.adapter.is_unconditional_run = False - # allow any manipulations to take place to embeddings - gen_config.post_process_embeddings( - conditional_embeds, - unconditional_embeds, - ) - - if self.decorator is not None: - # apply the decorator to the embeddings - conditional_embeds.text_embeds = self.decorator( - conditional_embeds.text_embeds) - unconditional_embeds.text_embeds = self.decorator( - unconditional_embeds.text_embeds, is_unconditional=True) - - if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ - and gen_config.adapter_image_path is not None: - # apply the image projection - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - validation_image) - unconditional_clip_embeds = 
self.adapter.get_clip_image_embeds_from_tensors(validation_image, - True) - conditional_embeds = self.adapter( - conditional_embeds, conditional_clip_embeds, is_unconditional=False) - unconditional_embeds = self.adapter( - unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) - - if self.adapter is not None and isinstance(self.adapter, CustomAdapter): - conditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=conditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_generating_samples=True, - ) - unconditional_embeds = self.adapter.condition_encoded_embeds( - tensors_0_1=validation_image, - prompt_embeds=unconditional_embeds, - is_training=False, - has_been_preprocessed=False, - is_unconditional=True, - is_generating_samples=True, + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter( + conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter( + unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = 
self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values( + extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like( + extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError( + "Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + + img = self.generate_single_image( + pipeline, + gen_config, + conditional_embeds, + unconditional_embeds, + generator, + extra, ) - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( - gen_config.extra_values) > 0: - extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, - dtype=self.torch_dtype) - # apply extra values to the embeddings - self.adapter.add_extra_values( - extra_values, is_unconditional=False) - self.adapter.add_extra_values(torch.zeros_like( - extra_values), is_unconditional=True) - pass # todo remove, for debugging - - if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: - # if we have a refiner loaded, set the denoising end at the refiner start - extra['denoising_end'] = 
gen_config.refiner_start_at - extra['output_type'] = 'latent' - if not self.is_xl: - raise ValueError( - "Refiner is only supported for XL models") - - conditional_embeds = conditional_embeds.to( - self.device_torch, dtype=self.unet.dtype) - unconditional_embeds = unconditional_embeds.to( - self.device_torch, dtype=self.unet.dtype) - - img = self.generate_single_image( - pipeline, - gen_config, - conditional_embeds, - unconditional_embeds, - generator, - extra, - ) - - gen_config.save_image(img, i) - gen_config.log_image(img, i) - self._after_sample_image(i, len(image_configs)) - flush() + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): self.adapter.clear_memory() - # clear pipeline and cache to reduce vram usage - del pipeline - torch.cuda.empty_cache() + finally: + # Unload sampling transformer from GPU and restore main model to device + if self._sampling_transformer is not None: + self._sampling_transformer.to("cpu") + self.model.to(self.device_torch, dtype=self.torch_dtype) + if is_debug_enabled(): + print_acc("Unloaded sampling transformer to CPU") + # Ensure CUDA work is finished so VRAM is actually released before next use + if torch.cuda.is_available(): + torch.cuda.synchronize() + # Clear pipeline and cache to reduce vram usage (only if we created pipeline here) + if pipeline_created: + try: + del pipeline + except NameError: + pass + torch.cuda.empty_cache() # restore training state torch.set_rng_state(rng_state) @@ -682,7 +707,6 @@ def generate_images( network.train() network.multiplier = start_multiplier - self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index b8556f125..e27e83f28 100644 --- a/toolkit/network_mixins.py 
+++ b/toolkit/network_mixins.py @@ -547,8 +547,8 @@ def get_state_dict(self: Network, extra_state_dict=None, dtype=torch.float16): new_save_dict = {} for key, value in save_dict.items(): # lokr needs alpha - if key.endswith('.alpha') and self.network_type.lower() != "lokr": - continue + # if key.endswith('.alpha') and self.network_type.lower() != "lokr": + # continue new_key = key new_key = new_key.replace('lora_down', 'lora_A') new_key = new_key.replace('lora_up', 'lora_B') @@ -633,21 +633,21 @@ def load_weights(self: Network, file, force_weight_mapping=False): # lora_down = lora_A # lora_up = lora_B # no alpha - if load_key.endswith('.alpha') and self.network_type.lower() != "lokr": - continue + # if load_key.endswith('.alpha') and self.network_type.lower() != "lokr": + # continue load_key = load_key.replace('lora_A', 'lora_down') load_key = load_key.replace('lora_B', 'lora_up') # replace all . with $$ load_key = load_key.replace('.', '$$') load_key = load_key.replace('$$lora_down$$', '.lora_down.') load_key = load_key.replace('$$lora_up$$', '.lora_up.') - + if load_key.endswith('$$alpha'): + load_key = load_key[:-7] + '.alpha' + # patch lokr, not sure why we need to but whatever if self.network_type.lower() == "lokr": load_key = load_key.replace('$$lokr_w1', '.lokr_w1') load_key = load_key.replace('$$lokr_w2', '.lokr_w2') - if load_key.endswith('$$alpha'): - load_key = load_key[:-7] + '.alpha' if self.network_type.lower() == "lokr": # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1 @@ -713,9 +713,10 @@ def load_weights(self: Network, file, force_weight_mapping=False): return self.load_weights(file, force_weight_mapping=True) info = self.load_state_dict(load_sd, False) - if len(extra_dict.keys()) == 0: - extra_dict = None - return extra_dict + del load_sd + if isinstance(file, str): + del weights_sd + return None @torch.no_grad() def _update_torch_multiplier(self: Network): diff --git 
a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py index 8897bdc07..16f16508e 100644 --- a/toolkit/optimizers/adafactor.py +++ b/toolkit/optimizers/adafactor.py @@ -32,6 +32,8 @@ class Adafactor(torch.optim.Optimizer): Coefficient used to compute running averages of square beta1 (`float`, *optional*): Coefficient used for computing running averages of gradient + (first moment, like in Adam). If not None, enables momentum. + Suggested values: 0.9 (default), 0.95 or 0.99 for smoother updates. weight_decay (`float`, *optional*, defaults to 0.0): Weight decay (L2 penalty) scale_parameter (`bool`, *optional*, defaults to `True`): @@ -40,6 +42,12 @@ class Adafactor(torch.optim.Optimizer): If True, time-dependent learning rate is computed instead of external learning rate warmup_init (`bool`, *optional*, defaults to `False`): Time-dependent learning rate computation depends on whether warm-up initialization is being used + min_lr (`float`, *optional*, defaults to `1e-6`): + Minimum learning rate multiplier for warmup phase when `warmup_init=True` and `relative_step=True`. + Controls the linear growth rate: `lr = min_lr * step` during warmup. + max_lr (`float`, *optional*, defaults to `1e-4`): + Maximum learning rate cap for relative step mode when `relative_step=True`. + Acts as upper bound for `min_step` when `warmup_init=False` or when warmup phase completes. This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. 
@@ -106,13 +114,15 @@ def __init__( scale_parameter=True, relative_step=True, warmup_init=False, - do_paramiter_swapping=False, - paramiter_swapping_factor=0.1, + min_lr=1e-6, + max_lr=1e-4, + do_parameter_swapping=False, + parameter_swapping_factor=0.1, stochastic_accumulation=True, stochastic_rounding=True, ): self.stochastic_rounding = stochastic_rounding - if lr is not None and relative_step: + if lr is not None and lr != 0 and relative_step: raise ValueError( "Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: @@ -129,11 +139,17 @@ def __init__( "scale_parameter": scale_parameter, "relative_step": relative_step, "warmup_init": warmup_init, + "min_lr": min_lr, + "max_lr": max_lr, } super().__init__(params, defaults) + # Store LR limits so they can be reapplied after load_state_dict (restart with new config). + self._min_lr = min_lr + self._max_lr = max_lr + self.base_lrs: List[float] = [ - lr for group in self.param_groups + group['lr'] for group in self.param_groups ] self.is_stochastic_rounding_accumulation = False @@ -148,60 +164,70 @@ def __init__( stochastic_grad_accummulation ) - self.do_paramiter_swapping = do_paramiter_swapping - self.paramiter_swapping_factor = paramiter_swapping_factor - self._total_paramiter_size = 0 - # count total paramiters + self.do_parameter_swapping = do_parameter_swapping + self.parameter_swapping_factor = parameter_swapping_factor + self._total_parameter_size = 0 + # count total parameters for group in self.param_groups: for param in group['params']: - self._total_paramiter_size += torch.numel(param) - # pretty print total paramiters with comma seperation - print(f"Total training paramiters: {self._total_paramiter_size:,}") - - # needs to be enabled to count paramiters - if self.do_paramiter_swapping: - self.enable_paramiter_swapping(self.paramiter_swapping_factor) + self._total_parameter_size += torch.numel(param) + # pretty print total parameters with comma separation + 
print(f"Total training parameters: {self._total_parameter_size:,}") - - def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1): - self.do_paramiter_swapping = True - self.paramiter_swapping_factor = paramiter_swapping_factor + # needs to be enabled to count parameters + if self.do_parameter_swapping: + self.enable_parameter_swapping(self.parameter_swapping_factor) + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + # Apply current run's min_lr/max_lr so changed config is used after restart. + for group in self.param_groups: + group["min_lr"] = self._min_lr + group["max_lr"] = self._max_lr + + def enable_parameter_swapping(self, parameter_swapping_factor=0.1): + self.do_parameter_swapping = True + self.parameter_swapping_factor = parameter_swapping_factor # call it an initial time - self.swap_paramiters() + self.swap_parameters() - def swap_paramiters(self): + def swap_parameters(self): all_params = [] - # deactivate all paramiters + # deactivate all parameters for group in self.param_groups: for param in group['params']: param.requires_grad_(False) # remove any grad param.grad = None all_params.append(param) - # shuffle all paramiters + # shuffle all parameters random.shuffle(all_params) - # keep activating paramiters until we are going to go over the target paramiters - target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor) - total_paramiters = 0 + # keep activating parameters until we are going to go over the target parameters + target_parameters = max(1, int(self._total_parameter_size * self.parameter_swapping_factor)) + total_parameters = 0 for param in all_params: - total_paramiters += torch.numel(param) - if total_paramiters >= target_paramiters: + param.requires_grad_(True) + total_parameters += torch.numel(param) + if total_parameters >= target_parameters: break - else: - param.requires_grad_(True) @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] if 
param_group["relative_step"]: - min_step = 1e-6 * \ - param_state["step"] if param_group["warmup_init"] else 1e-2 + if param_group["warmup_init"]: + min_step = param_group["min_lr"] * param_state["step"] + else: + min_step = param_group["max_lr"] rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) param_scale = 1.0 if param_group["scale_parameter"]: param_scale = max(param_group["eps"][1], param_state["RMS"]) - return param_scale * rel_step_sz + lr = param_scale * rel_step_sz + # Ensure learning rate is between min_lr and max_lr + lr = max(param_group["min_lr"], min(lr, param_group["max_lr"])) + return lr @staticmethod def _get_options(param_group, param_shape): @@ -234,11 +260,20 @@ def step_hook(self): # adafactor manages its own lr def get_learning_rates(self): - lrs = [ - self._get_lr(group, self.state[group["params"][0]]) - for group in self.param_groups - if group["params"][0].grad is not None - ] + lrs = [] + for group in self.param_groups: + # Find first param with initialized state + lr = None + for param in group["params"]: + if param in self.state and len(self.state[param]) > 0: + lr = self._get_lr(group, self.state[param]) + break + if lr is not None: + lrs.append(lr) + elif group["lr"] is not None: + # Fallback to group lr if state not initialized + lrs.append(group["lr"]) + if len(lrs) == 0: lrs = self.base_lrs # if called before stepping return lrs @@ -288,6 +323,7 @@ def step(self, closure=None): if factored: state["exp_avg_sq_row"] = torch.zeros( grad_shape[:-1]).to(grad) + # For 2D tensors, grad_shape[:-2] is empty tuple, which is correct for column stats state["exp_avg_sq_col"] = torch.zeros( grad_shape[:-2] + grad_shape[-1:]).to(grad) else: @@ -306,8 +342,9 @@ def step(self, closure=None): state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) p_data_fp32 = p + is_quantized = isinstance(p_data_fp32, QBytesTensor) - if isinstance(p_data_fp32, QBytesTensor): + if is_quantized: p_data_fp32 = p_data_fp32.dequantize() if p.dtype != 
torch.float32: p_data_fp32 = p_data_fp32.clone().float() @@ -356,8 +393,54 @@ def step(self, closure=None): p_data_fp32.add_(-update) - if p.dtype != torch.float32 and self.stochastic_rounding: + # Store update RMS for monitoring + state["update_rms"] = self._rms(update).item() + + if (p.dtype != torch.float32 or is_quantized) and self.stochastic_rounding: # apply stochastic rounding copy_stochastic(p, p_data_fp32) return loss + + def get_avg_learning_rate(self): + lrs = self.get_learning_rates() + if len(lrs) == 0: + return 0.0 + return sum(lrs) / len(lrs) + + def get_update_rms(self): + """ + Get RMS (root mean square) of weight updates for each parameter group. + + Returns: + List[float]: RMS of weight updates for each parameter group. + Returns 0.0 for groups that haven't been updated yet. + """ + update_rms_list = [] + for group in self.param_groups: + group_rms_sum = 0.0 + group_count = 0 + for p in group["params"]: + if p in self.state and "update_rms" in self.state[p]: + group_rms_sum += self.state[p]["update_rms"] + group_count += 1 + if group_count > 0: + update_rms_list.append(group_rms_sum / group_count) + else: + update_rms_list.append(0.0) + return update_rms_list + + def get_avg_update_rms(self): + """ + Get average RMS of weight updates across all parameter groups. + + This metric represents the average magnitude of weight changes per optimization step. + Useful for monitoring training stability and convergence. + + Returns: + float: Average RMS of weight updates across all parameter groups. 
+ """ + update_rms_list = self.get_update_rms() + if len(update_rms_list) == 0: + return 0.0 + return sum(update_rms_list) / len(update_rms_list) diff --git a/toolkit/paths.py b/toolkit/paths.py index edd36ce19..ad8a31394 100644 --- a/toolkit/paths.py +++ b/toolkit/paths.py @@ -1,4 +1,5 @@ import os +from typing import Optional TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config') @@ -22,3 +23,13 @@ def get_path(path): if not os.path.isabs(path): path = os.path.join(TOOLKIT_ROOT, path) return path + + +def normalize_path(path: Optional[str]) -> Optional[str]: + """Strip leading/trailing whitespace and trailing path separators. + Use for any path: model dirs, LoRA/safetensors files, etc. + """ + if not isinstance(path, str): + return path + path = path.strip() + return path.rstrip(os.sep).rstrip("/") diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 0f3c10a9b..9813a7056 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -114,6 +114,39 @@ def expand_to_batch(self, batch_size): pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe + def shuffle_sequence(self) -> None: + """Shuffle token order along sequence dimension (dim=1), keeping first token fixed. In-place. 
No-op if seq_len <= 1.""" + def make_perm(seq_len: int, device: torch.device) -> torch.Tensor: + if seq_len <= 1: + return torch.arange(seq_len, device=device, dtype=torch.long) + return torch.cat([ + torch.zeros(1, device=device, dtype=torch.long), + 1 + torch.randperm(seq_len - 1, device=device, dtype=torch.long), + ]) + if isinstance(self.text_embeds, list) or isinstance(self.text_embeds, tuple): + te_list = list(self.text_embeds) + attn_list = list(self.attention_mask) if self.attention_mask is not None and isinstance(self.attention_mask, (list, tuple)) else None + attn_is_tuple = isinstance(self.attention_mask, tuple) if self.attention_mask is not None else False + for i, t in enumerate(te_list): + if t.dim() >= 2 and t.shape[1] > 1: + perm = make_perm(t.shape[1], t.device) + te_list[i] = t[:, perm, ...].contiguous() + if attn_list is not None and i < len(attn_list): + attn = attn_list[i] + if attn.dim() >= 2 and attn.shape[1] > 1: + attn_list[i] = attn[:, perm, ...].contiguous() + elif self.attention_mask is not None and not isinstance(self.attention_mask, (list, tuple)) and i == 0: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + self.text_embeds = tuple(te_list) if isinstance(self.text_embeds, tuple) else te_list + if attn_list is not None: + self.attention_mask = tuple(attn_list) if attn_is_tuple else attn_list + else: + if self.text_embeds.dim() >= 2 and self.text_embeds.shape[1] > 1: + perm = make_perm(self.text_embeds.shape[1], self.text_embeds.device) + self.text_embeds = self.text_embeds[:, perm, ...].contiguous() + if self.attention_mask is not None and self.attention_mask.dim() >= 2 and self.attention_mask.shape[1] > 1: + self.attention_mask = self.attention_mask[:, perm, ...].contiguous() + def save(self, path: str): """ Save the prompt embeds to a file. 
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index bac7f3fde..cc1e71033 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init_noise_sigma = 1.0 self.timestep_type = "linear" + self._alphas_cumprod = None # Lazy initialization with torch.no_grad(): # create weights for timesteps @@ -56,6 +57,60 @@ def __init__(self, *args, **kwargs): self.linear_timesteps_weights2 = hbsmntw_weighing pass + def _compute_alphas_cumprod(self): + """ + Compute equivalent alphas_cumprod for flow matching for SNR weighting compatibility. + + For flow matching: x_t = (1-t)*x_0 + t*noise + Equivalent SNR: (1-t)^2 / t^2 + + For DDPM: SNR = alphas_cumprod / (1 - alphas_cumprod) + Therefore: alphas_cumprod = (1-t)^2 + """ + num_timesteps = 1000 + # Create timesteps from 1000 to 0 (descending) + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + + # Normalize to [0, 1] range + t = timesteps / 1000.0 + + # Clamp to avoid numerical issues at boundaries + t = torch.clamp(t, min=1e-8, max=1.0 - 1e-8) + + # Compute equivalent alphas_cumprod: (1-t)^2 + # This gives correct SNR: alphas_cumprod / (1 - alphas_cumprod) = (1-t)^2 / t^2 + alphas_cumprod = (1.0 - t) ** 2 + + return alphas_cumprod + + def compute_snr(self): + """ + Compute SNR for each timestep in flow matching. + + For flow matching: x_t = (1-t)*x_0 + t*noise + Signal = (1-t)*x_0, Noise = t*noise + SNR = (1-t)^2 / t^2 + """ + num_timesteps = 1000 + timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') + t = timesteps / 1000.0 + + # Add small epsilon to avoid division by zero + epsilon = 1e-8 + snr = ((1.0 - t) ** 2) / (t ** 2 + epsilon) + + return snr + + @property + def alphas_cumprod(self): + """ + Returns equivalent alphas_cumprod for flow matching. 
+ This provides compatibility with SNR weighting functions in train_tools.py. + """ + if self._alphas_cumprod is None: + self._alphas_cumprod = self._compute_alphas_cumprod() + return self._alphas_cumprod + def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False, timestep_type="linear") -> torch.Tensor: # Get the indices of the timesteps step_indices = [(self.timesteps == t).nonzero().item() diff --git a/toolkit/scheduler.py b/toolkit/scheduler.py index f6f8f61ae..efd4cf5e0 100644 --- a/toolkit/scheduler.py +++ b/toolkit/scheduler.py @@ -3,22 +3,130 @@ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION, get_constant_schedule_with_warmup +class SequentialLRWrapper(torch.optim.lr_scheduler.SequentialLR): + """ + Wrapper for SequentialLR that ignores extra arguments to step(). + + This is needed because the training code calls lr_scheduler.step(step_num), + but SequentialLR.step() doesn't accept any arguments. + """ + def step(self, *args, **kwargs): + # Ignore all arguments and call parent step() without them + super().step() + + +def _create_scheduler_with_warmup( + scheduler_type: str, + optimizer: torch.optim.Optimizer, + **scheduler_kwargs +): + """ + Creates a scheduler with optional warmup period using SequentialLR. + + Args: + scheduler_type: 'cosine' or 'cosine_with_restarts' + optimizer: The optimizer to schedule + scheduler_kwargs: Parameters for the scheduler. Can include: + - warmup_steps: Number of warmup steps (default: 0, no warmup) + - total_iters: TOTAL number of iterations INCLUDING warmup (default: 1000) + - T_0/T_max: Iterations for MAIN scheduler, overrides calculation from total_iters + - Other scheduler parameters (T_mult, eta_min, etc.) 
+ + Semantics: + - total_iters: Total training iterations (warmup + main scheduler) + - T_0/T_max: Main scheduler iterations (if specified, total_iters is ignored) + - If only total_iters specified: main_iters = total_iters - warmup_steps + - If T_0/T_max specified: main_iters = T_0/T_max (total_iters ignored) + + Returns: + A scheduler (SequentialLR if warmup_steps > 0, otherwise base scheduler) + """ + # Extract warmup_steps (default: 0, no warmup) + warmup_steps = scheduler_kwargs.pop('warmup_steps', 0) + + # Extract total_iters (GENERAL total, including warmup) + total_iters = scheduler_kwargs.pop('total_iters', 1000) + + # Calculate main scheduler iterations + # T_0/T_max have priority and specify main scheduler iterations directly + if scheduler_type == "cosine": + if 'T_max' in scheduler_kwargs: + # T_max specifies main scheduler iterations (ignores total_iters) + main_total_iters = scheduler_kwargs.pop('T_max') + else: + # Calculate from total_iters + main_total_iters = total_iters - warmup_steps + elif scheduler_type == "cosine_with_restarts": + if 'T_0' in scheduler_kwargs: + # T_0 specifies main scheduler iterations (ignores total_iters) + main_total_iters = scheduler_kwargs.pop('T_0') + else: + # Calculate from total_iters + main_total_iters = total_iters - warmup_steps + + # Validation: warn if configuration seems incorrect + if main_total_iters <= 0: + raise ValueError( + f"Main scheduler iterations must be positive, got {main_total_iters}. " + f"Check your total_iters ({total_iters}) and warmup_steps ({warmup_steps})." + ) + if warmup_steps > 0 and warmup_steps >= total_iters: + print(f"WARNING: warmup_steps ({warmup_steps}) >= total_iters ({total_iters}). 
" + f"The main scheduler will have very few or no iterations.") + + if warmup_steps <= 0: + # No warmup, create base scheduler directly + if scheduler_type == "cosine": + return torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=main_total_iters, **scheduler_kwargs + ) + elif scheduler_type == "cosine_with_restarts": + return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=main_total_iters, **scheduler_kwargs + ) + + # Create warmup scheduler (linear from ~0 to 1.0) + warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=0.1, # 1/10 of the target LR + end_factor=1.0, # End at full LR + total_iters=warmup_steps + ) + + # Create main scheduler + if scheduler_type == "cosine": + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=main_total_iters, **scheduler_kwargs + ) + elif scheduler_type == "cosine_with_restarts": + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, T_0=main_total_iters, **scheduler_kwargs + ) + + # Combine schedulers using SequentialLRWrapper + combined_scheduler = SequentialLRWrapper( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_steps] + ) + + return combined_scheduler + + def get_lr_scheduler( name: Optional[str], optimizer: torch.optim.Optimizer, **kwargs, ): if name == "cosine": - if 'total_iters' in kwargs: - kwargs['T_max'] = kwargs.pop('total_iters') - return torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, **kwargs + # All parameters passed via kwargs, handled in _create_scheduler_with_warmup + return _create_scheduler_with_warmup( + "cosine", optimizer, **kwargs ) elif name == "cosine_with_restarts": - if 'total_iters' in kwargs: - kwargs['T_0'] = kwargs.pop('total_iters') - return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( - optimizer, **kwargs + # All parameters passed via kwargs, handled in _create_scheduler_with_warmup + return 
_create_scheduler_with_warmup( + "cosine_with_restarts", optimizer, **kwargs ) elif name == "step": diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index 78e2183c3..f75c41961 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -624,7 +624,16 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device): try: if hasattr(noise_scheduler, "all_snr"): return - # compute it + + # Handle flow matching schedulers that have compute_snr method + if hasattr(noise_scheduler, "compute_snr") and callable(getattr(noise_scheduler, "compute_snr")): + with torch.no_grad(): + all_snr = noise_scheduler.compute_snr() + all_snr.requires_grad = False + noise_scheduler.all_snr = all_snr.to(device) + return + + # Standard DDPM/DDIM scheduler path with torch.no_grad(): alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) @@ -642,6 +651,15 @@ def add_all_snr_to_noise_scheduler(noise_scheduler, device): def get_all_snr(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return noise_scheduler.all_snr.to(device) + + # Handle flow matching schedulers that have compute_snr method + if hasattr(noise_scheduler, "compute_snr") and callable(getattr(noise_scheduler, "compute_snr")): + with torch.no_grad(): + all_snr = noise_scheduler.compute_snr() + all_snr.requires_grad = False + return all_snr.to(device) + + # Standard DDPM/DDIM scheduler path # compute it with torch.no_grad(): alphas_cumprod = noise_scheduler.alphas_cumprod diff --git a/toolkit/util/debug.py b/toolkit/util/debug.py new file mode 100644 index 000000000..61d72a2f2 --- /dev/null +++ b/toolkit/util/debug.py @@ -0,0 +1,82 @@ +""" +Debug utilities. memory_debug context manager measures GPU/RAM around a block; +enabled via set_debug_config(config) with config.debug. Extensible to RAM later. 
+""" +import contextlib +from typing import Callable + +import torch + +_debug_config = None + + +def set_debug_config(config) -> None: + """Register the config object used to decide if memory debug is enabled (config.debug).""" + global _debug_config + _debug_config = config + + +def is_debug_enabled() -> bool: + """Return True if debug logging is enabled (config.debug). Used for optional debug messages.""" + if _debug_config is None: + return False + return bool(getattr(_debug_config, "debug", False)) + + +def _is_enabled_for_cuda() -> bool: + if _debug_config is None: + return False + if not getattr(_debug_config, "debug", False): + return False + return torch.cuda.is_available() + + +def _cuda_snapshot_mb(): + """Return (allocated_mb, max_allocated_mb).""" + return ( + torch.cuda.memory_allocated() / 2**20, + torch.cuda.max_memory_allocated() / 2**20, + ) + + +def _format_cuda_diff(label: str, before: tuple, after: tuple) -> list: + mem_before, max_before = before + mem_after, max_after = after + delta = mem_before - mem_after + delta_str = f"(freed {delta:.1f} MB)" if delta >= 0 else f"(+{-delta:.1f} MB)" + return [ + f"[DEBUG {label}] CUDA allocated: {mem_before:.1f} MB -> {mem_after:.1f} MB {delta_str}", + f"[DEBUG {label}] CUDA max: {max_before:.1f} MB -> {max_after:.1f} MB", + ] + + +@contextlib.contextmanager +def memory_debug( + print_fn: Callable[[str], None], + label: str, + kind: str = "cuda", +): + """ + Context manager: measure memory around the block and log if debug is enabled. + enabled is read from the config set via set_debug_config(); no need to pass it. + kind="cuda" measures CUDA allocated/max; other kinds (e.g. "ram") are stubs for now. 
+ """ + if kind != "cuda": + yield + return + if not _is_enabled_for_cuda(): + yield + return + before = _cuda_snapshot_mb() + try: + yield + finally: + torch.cuda.synchronize() + after = _cuda_snapshot_mb() + for line in _format_cuda_diff(label, before, after): + print_fn(line) + + +def cuda_memory_debug(print_fn: Callable[[str], None], label: str): + """Alias for memory_debug(print_fn, label, kind="cuda").""" + return memory_debug(print_fn, label, kind="cuda") diff --git a/toolkit/util/device.py b/toolkit/util/device.py new file mode 100644 index 000000000..0a06817c0 --- /dev/null +++ b/toolkit/util/device.py @@ -0,0 +1,28 @@ +""" +Device utilities. safe_module_to_device moves a module in-place without using +Module.to(), avoiding PyTorch's swap_tensors path which fails on quantized +parameters (e.g. QLinear) that have requires_grad=False. +""" +from typing import Optional + +import torch + + +def safe_module_to_device( + module: torch.nn.Module, + device: torch.device, + dtype: Optional[torch.dtype] = None, +) -> None: + """ + Move module to device (and optionally dtype) by moving param/buffer .data in-place. + Avoids Module.to() which uses swap_tensors and can raise on tensors + that do not require gradients (e.g. QLinear in quantized models). 
+ """ + for _, param in module.named_parameters(recurse=False): + if param.device != device or (dtype is not None and param.dtype != dtype): + param.data = param.data.to(device=device, dtype=dtype or param.dtype) + for _, buf in module.named_buffers(recurse=False): + if buf.device != device or (dtype is not None and buf.dtype != dtype): + buf.data = buf.data.to(device=device, dtype=dtype or buf.dtype) + for _, child in module.named_children(): + safe_module_to_device(child, device, dtype) diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index 27d6733a5..af188be89 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -129,6 +129,8 @@ def quantize( def quantize_model( base_model: "BaseModel", model_to_quantize: torch.nn.Module, + use_accuracy_recovery: bool = True, + adapter_attr_name: str = "accuracy_recovery_adapter", ): from toolkit.dequantize import patch_dequantization_on_save @@ -140,7 +142,7 @@ def quantize_model( # patch the state dict method patch_dequantization_on_save(model_to_quantize) - if base_model.model_config.accuracy_recovery_adapter is not None: + if use_accuracy_recovery and base_model.model_config.accuracy_recovery_adapter is not None: from toolkit.config_modules import NetworkConfig from toolkit.lora_special import LoRASpecialNetwork diff --git a/ui/src/app/api/jobs/[jobID]/stop/route.ts b/ui/src/app/api/jobs/[jobID]/stop/route.ts index 73b352dfc..9d1ce421e 100644 --- a/ui/src/app/api/jobs/[jobID]/stop/route.ts +++ b/ui/src/app/api/jobs/[jobID]/stop/route.ts @@ -1,8 +1,20 @@ import { NextRequest, NextResponse } from 'next/server'; import { PrismaClient } from '@prisma/client'; +import { getTrainingFolder } from '@/server/settings'; +import path from 'path'; +import fs from 'fs'; const prisma = new PrismaClient(); +function isProcessAlive(pid: number): boolean { + try { + process.kill(pid, 0); + return true; + } catch { + return false; + } +} + export async function GET(request: NextRequest, { params }: { params: 
{ jobID: string } }) { const { jobID } = await params; @@ -10,14 +22,54 @@ export async function GET(request: NextRequest, { params }: { params: { jobID: s where: { id: jobID }, }); - // update job status to 'running' - await prisma.job.update({ - where: { id: jobID }, - data: { - stop: true, - info: 'Stopping job...', - }, - }); + if (!job) { + return NextResponse.json({ error: 'Job not found' }, { status: 404 }); + } + + const markStopped = (info: string) => + prisma.job.update({ + where: { id: jobID }, + data: { + stop: true, + status: 'stopped', + info, + }, + }); + + if (job.status !== 'running') { + const updated = await markStopped(job.status === 'stopped' ? job.info : 'Job stopped'); + return NextResponse.json(updated); + } + + let pid: number | null = null; + try { + const trainingRoot = await getTrainingFolder(); + const pidPath = path.join(trainingRoot, job.name, 'pid.txt'); + if (fs.existsSync(pidPath)) { + const raw = fs.readFileSync(pidPath, 'utf8').trim(); + const n = parseInt(raw, 10); + if (Number.isInteger(n) && n > 0) pid = n; + } + } catch { + // pid file missing or unreadable — treat as no process + } + + if (pid === null) { + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); + } + + if (!isProcessAlive(pid)) { + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); + } + + try { + process.kill(pid, 'SIGTERM'); + } catch { + // Process may have exited; still mark as stopped + } - return NextResponse.json(job); + const updated = await markStopped('Job stopped'); + return NextResponse.json(updated); } diff --git a/ui/src/app/jobs/new/SimpleJob.tsx b/ui/src/app/jobs/new/SimpleJob.tsx index 5db650c30..fd20e39aa 100644 --- a/ui/src/app/jobs/new/SimpleJob.tsx +++ b/ui/src/app/jobs/new/SimpleJob.tsx @@ -531,12 +531,15 @@ export default function SimpleJob({ setJobConfig(value, 'config.process[0].train.content_or_style')} options={[ { value: 'balanced', label: 'Balanced' }, { value: 
'content', label: 'High Noise' }, { value: 'style', label: 'Low Noise' }, + { value: 'gaussian', label: 'Gaussian (Normal)' }, + { value: 'fixed_cycle', label: 'Fixed Cycle' }, ]} /> + {jobConfig.config.process[0].train.content_or_style === 'fixed_cycle' && ( + <> + { + const arr = value + .split(',') + .map(s => parseFloat(s.trim())) + .filter(n => !isNaN(n)); + setJobConfig(arr, 'config.process[0].train.fixed_cycle_timesteps'); + }} + placeholder="eg. 999, 875, 750, 625, 500, 375, 250, 125" + /> + setJobConfig(value ? parseInt(value as string) : null, 'config.process[0].train.fixed_cycle_seed')} + placeholder="eg. 42 (leave empty for no shuffle)" + min={0} + /> + { + const arr = value + .split(',') + .map(s => parseFloat(s.trim())) + .filter(n => !isNaN(n)); + setJobConfig(arr.length > 0 ? arr : null, 'config.process[0].train.fixed_cycle_weight_peak_timesteps'); + }} + placeholder="eg. 500, 375 (leave empty to disable)" + /> + setJobConfig(value, 'config.process[0].train.fixed_cycle_weight_sigma')} + placeholder="eg. 372.8" + min={0} + step={0.1} + /> + + )}
@@ -675,6 +737,21 @@ export default function SimpleJob({ placeholder="eg. 1.0" min={0} /> + + setJobConfig(value, 'config.process[0].train.blank_prompt_probability') + } + placeholder="eg. 0.1 for 10%, 1.0 for 100%" + min={0} + max={1} + step={0.1} + /> )} diff --git a/ui/src/components/JobLossGraph.tsx b/ui/src/components/JobLossGraph.tsx index 78a77bfb5..1a736a0db 100644 --- a/ui/src/components/JobLossGraph.tsx +++ b/ui/src/components/JobLossGraph.tsx @@ -1,7 +1,7 @@ 'use client'; import { Job } from '@prisma/client'; -import useJobLossLog, { LossPoint } from '@/hooks/useJobLossLog'; +import useJobLossLog, { LossPoint, MetricFilter } from '@/hooks/useJobLossLog'; import { useMemo, useState, useEffect } from 'react'; import { ResponsiveContainer, LineChart, Line, XAxis, YAxis, Tooltip, CartesianGrid, Legend } from 'recharts'; @@ -65,7 +65,8 @@ function strokeForKey(key: string) { } export default function JobLossGraph({ job }: Props) { - const { series, lossKeys, status, refreshLoss } = useJobLossLog(job.id, 2000); + const [metricFilter, setMetricFilter] = useState('loss'); + const { series, filteredKeys, status, refreshLoss } = useJobLossLog(job.id, 2000, metricFilter); // Controls const [useLogScale, setUseLogScale] = useState(false); @@ -91,18 +92,18 @@ export default function JobLossGraph({ job }: Props) { useEffect(() => { setEnabled(prev => { const next = { ...prev }; - for (const k of lossKeys) { + for (const k of filteredKeys) { if (next[k] === undefined) next[k] = true; } // drop removed keys for (const k of Object.keys(next)) { - if (!lossKeys.includes(k)) delete next[k]; + if (!filteredKeys.includes(k)) delete next[k]; } return next; }); - }, [lossKeys]); + }, [filteredKeys]); - const activeKeys = useMemo(() => lossKeys.filter(k => enabled[k] !== false), [lossKeys, enabled]); + const activeKeys = useMemo(() => filteredKeys.filter(k => enabled[k] !== false), [filteredKeys, enabled]); const perSeries = useMemo(() => { // Build per-series processed point 
arrays (raw + smoothed), then merge by step for charting. @@ -318,7 +319,7 @@ export default function JobLossGraph({ job }: Props) { {/* Controls */}
-
+
@@ -329,13 +330,44 @@ export default function JobLossGraph({ job }: Props) {
+
+ +
+ setMetricFilter('loss')} + label="Loss" + /> + setMetricFilter('learning_rate')} + label="Learning Rate" + /> + setMetricFilter('diff_guidance')} + label="Diff Guidance" + /> + setMetricFilter('all')} + label="All" + /> + setMetricFilter('other')} + label="Other" + /> +
+
+
- {lossKeys.length === 0 ? ( -
No loss keys found yet.
+ {filteredKeys.length === 0 ? ( +
No keys found yet.
) : (
- {lossKeys.map(k => ( + {filteredKeys.map(k => (
-
+
{windowSize === 0 ? 'all' : windowSize.toLocaleString()} diff --git a/ui/src/components/SampleImageViewer.tsx b/ui/src/components/SampleImageViewer.tsx index e8804a9a4..799427835 100644 --- a/ui/src/components/SampleImageViewer.tsx +++ b/ui/src/components/SampleImageViewer.tsx @@ -82,7 +82,7 @@ export default function SampleImageViewer({ if (idx < 0 || idx >= sampleImages.length) return; onChange(sampleImages[idx]); }, - [sampleImages, numSamples, onChange], + [sampleImages, onChange], ); const currentIndex = useMemo(() => { @@ -167,9 +167,11 @@ export default function SampleImageViewer({ onCancel(); break; case 'ArrowUp': + event.preventDefault(); handleArrowUp(); break; case 'ArrowDown': + event.preventDefault(); handleArrowDown(); break; case 'ArrowLeft': diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx index c82d800bd..6da72c11c 100644 --- a/ui/src/docs.tsx +++ b/ui/src/docs.tsx @@ -287,6 +287,78 @@ const docs: { [key: string]: ConfigDoc } = { ), }, + 'train.blank_prompt_probability': { + title: 'BPP Probability', + description: ( + <> + Controls how often the Blank Prompt Preservation check runs during training. + Value between 0.0 and 1.0. Default is 1.0 (runs every step). + Setting to 0.1 means BPP runs ~10% of steps, reducing training time by up to 45% + while still preventing model degradation. Lower values give the model more freedom + to adapt to new concepts between BPP corrections. Recommended: 0.1-0.2 for Turbo models. + + ), + }, + 'train.content_or_style': { + title: 'Timestep Bias', + description: ( + <> + Controls how timesteps are sampled during training: +

+ Balanced: Uniform distribution across all timesteps. +

+ High Noise: Cubic distribution favoring earlier timesteps (more noise). Best for learning content/structure. +

+ Low Noise: Cubic distribution favoring later timesteps (less noise). Best for learning style/details. +

+ Gaussian (Normal): Normal distribution with configurable center and spread. Use gaussian_mean and gaussian_std in YAML config: +
+ • gaussian_mean (default 0.5): Center of distribution. Lower values (0.0-0.5) = more noise/earlier timesteps, higher values (0.5-1.0) = less noise/later timesteps. +
+ • gaussian_std (default 0.2): Spread of distribution. Smaller = narrower focus, larger = wider coverage. +
+ • gaussian_std_target (optional, default None): Enable curriculum learning. When set, gaussian_std will linearly interpolate from initial value to this target value during training. Example: start with gaussian_std: 0.001 (narrow distribution, focused training) → end with gaussian_std_target: 0.3 (wide distribution, diverse timestep coverage). +
+ • timestep_bias_exponent (default 3.0): Controls the cubic bias exponent for content/style timestep distribution. Higher values create stronger bias toward edges (early timesteps for content, late timesteps for style). Lower values create more uniform distribution. +

+ Fixed Cycle: Deterministic cycle over a fixed list of timestep values. Same step number always gets the same timestep, so training is reproducible. Recommended for distilled/Turbo models (e.g. Z-Image-Turbo LoRA) that are sensitive to random timestep sampling. Configure via YAML: fixed_cycle_timesteps, fixed_cycle_seed, fixed_cycle_weight_peak_timesteps. + + ), + }, + 'train.fixed_cycle_timesteps': { + title: 'Fixed Cycle Timesteps', + description: ( + <> + Used when content_or_style is fixed_cycle. List of timestep values (scheduler scale, typically 0–1000) to cycle through deterministically. At step s, the whole batch uses fixed_cycle_timesteps[s % length] (after optional shuffle by fixed_cycle_seed). Values are snapped to the nearest scheduler timestep. +

+ Default: [999, 875, 750, 625, 500, 375, 250, 125]. Set in config file (e.g. YAML) as an array of numbers. + + ), + }, + 'train.fixed_cycle_seed': { + title: 'Fixed Cycle Seed', + description: ( + <> + Used when content_or_style is fixed_cycle. If set, the list fixed_cycle_timesteps is shuffled once at the start of training using this seed. Same seed gives the same cycle order and reproducible runs. If null or unset, the cycle uses the order from the config. + + ), + }, + 'train.fixed_cycle_weight_peak_timesteps': { + title: 'Fixed Cycle Weight Peak Timesteps', + description: ( + <> + Used when content_or_style is fixed_cycle. Enables timestep-weighted loss (like timestep_type: weighted): loss is multiplied by weights with peaks at these timestep values. Default: [500, 375] so maximum weight is around 500 and 375. Set to null or empty to disable weighting. Set in config file as an array of numbers. + + ), + }, + 'train.fixed_cycle_weight_sigma': { + title: 'Fixed Cycle Weight Sigma', + description: ( + <> + Used when content_or_style is fixed_cycle and fixed_cycle_weight_peak_timesteps is set. Standard deviation (sigma) for the Gaussian distribution used to weight loss around the peak timesteps. Controls the width of the Gaussian peaks — larger values create wider, flatter distributions, smaller values create sharper peaks. Default: 372.8. 
+ + ), + }, 'train.do_differential_guidance': { title: 'Differential Guidance', description: ( diff --git a/ui/src/hooks/useJobLossLog.tsx b/ui/src/hooks/useJobLossLog.tsx index a0bf27ed3..5c4d062cb 100644 --- a/ui/src/hooks/useJobLossLog.tsx +++ b/ui/src/hooks/useJobLossLog.tsx @@ -11,13 +11,33 @@ export interface LossPoint { type SeriesMap = Record; +export type MetricFilter = 'loss' | 'learning_rate' | 'diff_guidance' | 'all' | 'other'; + +function categorizeMetric(key: string): 'loss' | 'learning_rate' | 'diff_guidance' | 'other' { + if (key === 'learning_rate') return 'learning_rate'; + if (key === 'diff_guidance_norm') return 'diff_guidance'; + if (/loss/i.test(key)) return 'loss'; + return 'other'; +} + +function matchesFilter(key: string, filter: MetricFilter): boolean { + if (filter === 'all') return true; + const category = categorizeMetric(key); + if (filter === 'other') return category === 'other'; + return category === filter; +} + function isLossKey(key: string) { // treat anything containing "loss" as a loss-series // (covers loss, train_loss, val_loss, loss/xyz, etc.) return /loss/i.test(key); } -export default function useJobLossLog(jobID: string, reloadInterval: null | number = null) { +export default function useJobLossLog( + jobID: string, + reloadInterval: null | number = null, + metricFilter: MetricFilter = 'loss' +) { const [series, setSeries] = useState({}); const [keys, setKeys] = useState([]); const [status, setStatus] = useState<'idle' | 'loading' | 'success' | 'error' | 'refreshing'>('idle'); @@ -28,12 +48,12 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb // track last step per key so polling is incremental per series const lastStepByKeyRef = useRef>({}); - const lossKeys = useMemo(() => { - const base = (keys ?? []).filter(isLossKey); + const filteredKeys = useMemo(() => { + const filtered = (keys ?? 
[]).filter(k => matchesFilter(k, metricFilter)); // if keys table is empty early on, fall back to just "loss" - if (base.length === 0) return ['loss']; - return base.sort(); - }, [keys]); + if (filtered.length === 0 && metricFilter === 'loss') return ['loss']; + return filtered.sort(); + }, [keys, metricFilter]); const refreshLoss = useCallback(async () => { if (!jobID) return; @@ -54,10 +74,13 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb const newKeys = first.keys ?? []; setKeys(newKeys); - const wantedLossKeys = (newKeys.filter(isLossKey).length ? newKeys.filter(isLossKey) : ['loss']).sort(); + const wantedKeys = (newKeys.filter(k => matchesFilter(k, metricFilter)).length + ? newKeys.filter(k => matchesFilter(k, metricFilter)) + : (metricFilter === 'loss' ? ['loss'] : []) + ).sort(); // Step 2: fetch each loss key incrementally (since_step per key if polling) - const requests = wantedLossKeys.map(k => { + const requests = wantedKeys.map(k => { const params: Record = { key: k }; if (reloadInterval && lastStepByKeyRef.current[k] != null) { @@ -101,9 +124,9 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb : (lastStepByKeyRef.current[k] ?? 
null); } - // remove stale loss keys that no longer exist (rare, but keeps UI clean) + // remove stale keys that no longer exist (rare, but keeps UI clean) for (const existingKey of Object.keys(next)) { - if (isLossKey(existingKey) && !wantedLossKeys.includes(existingKey)) { + if (matchesFilter(existingKey, metricFilter) && !wantedKeys.includes(existingKey)) { delete next[existingKey]; delete lastStepByKeyRef.current[existingKey]; } @@ -120,7 +143,7 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb } finally { inFlightRef.current = false; } - }, [jobID, reloadInterval]); + }, [jobID, reloadInterval, metricFilter]); useEffect(() => { // reset when job changes @@ -141,5 +164,5 @@ export default function useJobLossLog(jobID: string, reloadInterval: null | numb } }, [jobID, reloadInterval, refreshLoss]); - return { series, keys, lossKeys, status, refreshLoss, setSeries }; + return { series, keys, filteredKeys, status, refreshLoss, setSeries }; } diff --git a/ui/src/types.ts b/ui/src/types.ts index 07f1e1d18..8e1d692aa 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -142,10 +142,15 @@ export interface TrainConfig { diff_output_preservation_class: string; blank_prompt_preservation?: boolean; blank_prompt_preservation_multiplier?: number; + blank_prompt_probability?: number; switch_boundary_every: number; loss_type: 'mse' | 'mae' | 'wavelet' | 'stepped'; do_differential_guidance?: boolean; differential_guidance_scale?: number; + fixed_cycle_timesteps?: number[]; + fixed_cycle_seed?: number | null; + fixed_cycle_weight_peak_timesteps?: number[] | null; + fixed_cycle_weight_sigma?: number; } export interface QuantizeKwargsConfig {