# LCM Distillation for InstructPix2Pix SDXL

This repository contains a training script for distilling Latent Consistency Models (LCM) for InstructPix2Pix with Stable Diffusion XL (SDXL). The script enables fast, few-step image editing by distilling knowledge from a teacher model into a student model.

## Overview

This implementation performs **LCM distillation** on InstructPix2Pix SDXL models, which allows for:
- **Fast image editing** with significantly fewer sampling steps (1-4 steps vs 50+ steps)
- **Instruction-based image manipulation** using text prompts
- **High-quality outputs** that maintain the teacher model's capabilities

The training uses a teacher-student distillation approach (sketched in the example below) where:
- **Teacher**: Pre-trained InstructPix2Pix SDXL model (8-channel input U-Net)
- **Student**: Lightweight model with time conditioning that learns to match the teacher's outputs
- **Target**: EMA (Exponential Moving Average) copy of the student, used for stable training
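
A minimal sketch of how these three networks are typically set up, assuming the `diffusers` API and the default `--unet_time_cond_proj_dim=256`; the teacher model ID is the same placeholder used in the training examples below:

```python
from diffusers import UNet2DConditionModel

# Frozen teacher: the pre-trained 8-channel InstructPix2Pix SDXL U-Net.
teacher_unet = UNet2DConditionModel.from_pretrained(
    "diffusers/sdxl-instructpix2pix", subfolder="unet"  # placeholder teacher ID
)
teacher_unet.requires_grad_(False)

# Student: same architecture plus a time-conditioning projection for the guidance scale.
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=256)
unet.load_state_dict(teacher_unet.state_dict(), strict=False)  # copy the shared weights
unet.train()

# Target: EMA copy of the student; never optimized directly, only EMA-updated.
target_unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=256)
target_unet.load_state_dict(unet.state_dict())
target_unet.requires_grad_(False)
```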

## Requirements

```bash
pip install torch torchvision
pip install diffusers transformers accelerate
pip install datasets pillow
pip install tensorboard  # or wandb for logging
pip install xformers  # optional, for memory efficiency
pip install bitsandbytes  # optional, for 8-bit Adam optimizer
```

## Dataset Format

The script expects datasets with three components per sample (a quick inspection sketch follows the list):
1. **Original Image**: The input image to be edited
2. **Edit Prompt**: Text instruction describing the desired edit
3. **Edited Image**: The target output after applying the edit
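
As a sanity check, the hedged sketch below loads one row and inspects the three columns; the dataset name is a placeholder, and the column names must match `--original_image_column`, `--edit_prompt_column`, and `--edited_image_column`:

```python
from datasets import load_dataset

# Placeholder dataset; any InstructPix2Pix-style dataset with the three columns works.
dataset = load_dataset("your/instruct-pix2pix-dataset", split="train")

sample = dataset[0]
print(sample.keys())          # e.g. dict_keys(['original_image', 'edit_prompt', 'edited_image'])
print(sample["edit_prompt"])  # the text instruction
sample["original_image"].save("original.png")  # PIL image of the input
sample["edited_image"].save("edited.png")      # PIL image of the edit target
```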

### Supported Formats

**Option 1: HuggingFace Dataset**

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --dataset_name="your/dataset-name" \
  --dataset_config_name="default"
```

**Option 2: Local ImageFolder**

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --train_data_dir="./data/train"
```

## Key Arguments

### Model Configuration

- `--pretrained_teacher_model`: Path/ID of the teacher InstructPix2Pix SDXL model
- `--pretrained_vae_model_name_or_path`: Optional separate VAE model path
- `--vae_precision`: VAE precision (`fp16`, `fp32`, `bf16`)
- `--unet_time_cond_proj_dim`: Time conditioning projection dimension (default: 256)

### Dataset Arguments

- `--dataset_name`: HuggingFace dataset name
- `--train_data_dir`: Local training data directory
- `--original_image_column`: Column name for original images
- `--edit_prompt_column`: Column name for edit prompts
- `--edited_image_column`: Column name for edited images
- `--max_train_samples`: Limit number of training samples

### Training Configuration

- `--resolution`: Image resolution (default: 512)
- `--train_batch_size`: Batch size per device
- `--num_train_epochs`: Number of training epochs
- `--max_train_steps`: Maximum training steps
- `--gradient_accumulation_steps`: Gradient accumulation steps
- `--learning_rate`: Learning rate (default: 1e-4)
- `--lr_scheduler`: Learning rate scheduler type
- `--lr_warmup_steps`: Number of warmup steps

### LCM-Specific Arguments

- `--w_min`: Minimum guidance scale for sampling (default: 3.0)
- `--w_max`: Maximum guidance scale for sampling (default: 15.0)
- `--num_ddim_timesteps`: Number of DDIM timesteps (default: 50)
- `--loss_type`: Loss function type (`l2` or `huber`)
- `--huber_c`: Huber loss parameter (default: 0.001)
- `--ema_decay`: EMA decay rate for target network (default: 0.95)

### Optimization

- `--use_8bit_adam`: Use 8-bit Adam optimizer
- `--adam_beta1`, `--adam_beta2`: Adam optimizer betas
- `--adam_weight_decay`: Weight decay
- `--adam_epsilon`: Adam epsilon
- `--max_grad_norm`: Maximum gradient norm for clipping
- `--mixed_precision`: Mixed precision training (`no`, `fp16`, `bf16`)
- `--gradient_checkpointing`: Enable gradient checkpointing
- `--enable_xformers_memory_efficient_attention`: Use xFormers
- `--allow_tf32`: Allow TF32 on Ampere GPUs

### Validation

- `--val_image_url_or_path`: Validation image path/URL
- `--validation_prompt`: Validation edit prompt
- `--num_validation_images`: Number of validation images to generate
- `--validation_steps`: Validate every N steps

### Logging & Checkpointing

- `--output_dir`: Output directory for checkpoints
- `--logging_dir`: TensorBoard logging directory
- `--report_to`: Reporting integration (`tensorboard`, `wandb`)
- `--checkpointing_steps`: Save checkpoint every N steps
- `--checkpoints_total_limit`: Maximum number of checkpoints to keep
- `--resume_from_checkpoint`: Resume from checkpoint path

### Hub Integration

- `--push_to_hub`: Push model to HuggingFace Hub
- `--hub_token`: HuggingFace Hub token
- `--hub_model_id`: Hub model ID

## Training Example

### Basic Training

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --pretrained_teacher_model="diffusers/sdxl-instructpix2pix" \
  --dataset_name="your/instruct-pix2pix-dataset" \
  --output_dir="./output/lcm-sdxl-instruct" \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --max_train_steps=10000 \
  --validation_steps=500 \
  --checkpointing_steps=500 \
  --mixed_precision="fp16" \
  --seed=42
```

### Advanced Training with Optimizations

```bash
accelerate launch --multi_gpu train_lcm_distil_instruct_pix2pix_sdxl.py \
  --pretrained_teacher_model="diffusers/sdxl-instructpix2pix" \
  --dataset_name="your/instruct-pix2pix-dataset" \
  --output_dir="./output/lcm-sdxl-instruct" \
  --resolution=768 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=8 \
  --learning_rate=5e-5 \
  --max_train_steps=20000 \
  --num_ddim_timesteps=50 \
  --w_min=3.0 \
  --w_max=15.0 \
  --ema_decay=0.95 \
  --loss_type="huber" \
  --huber_c=0.001 \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --enable_xformers_memory_efficient_attention \
  --use_8bit_adam \
  --validation_steps=250 \
  --val_image_url_or_path="path/to/val_image.jpg" \
  --validation_prompt="make it sunset" \
  --num_validation_images=4 \
  --checkpointing_steps=500 \
  --checkpoints_total_limit=3 \
  --push_to_hub \
  --hub_model_id="your-username/lcm-sdxl-instruct" \
  --report_to="wandb"
```
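
With this configuration, the effective batch size is `train_batch_size × gradient_accumulation_steps × num_GPUs`; on 4 GPUs, for example, that is 2 × 8 × 4 = 64 samples per optimizer step.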

## How It Works

### Architecture

1. **Teacher U-Net**: Pre-trained 8-channel InstructPix2Pix SDXL U-Net
   - Input: Concatenated noisy latent + original image latent (8 channels)
   - Performs multi-step diffusion with classifier-free guidance

2. **Student U-Net**: Distilled model with time conditioning
   - Learns to predict in a single step what the teacher predicts in multiple steps
   - Uses a guidance scale embedding for conditioning

3. **Target U-Net**: EMA copy of the student
   - Provides stable training targets
   - Updated with an exponential moving average of the student's weights (see the sketch below)
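
A minimal sketch of that EMA update, applied after each optimizer step; the function name is illustrative and `ema_decay` corresponds to `--ema_decay`:

```python
import torch

@torch.no_grad()
def update_ema(target_unet, unet, ema_decay=0.95):
    """target <- ema_decay * target + (1 - ema_decay) * student, parameter by parameter."""
    for target_param, param in zip(target_unet.parameters(), unet.parameters()):
        target_param.mul_(ema_decay).add_(param, alpha=1.0 - ema_decay)
```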

### Training Process

The training loop implements the LCM distillation algorithm (a simplified sketch follows the list):
| 188 | + |
| 189 | +1. **Sample timestep** from DDIM schedule |
| 190 | +2. **Add noise** to latents at sampled timestep |
| 191 | +3. **Sample guidance scale** $w$ from uniform distribution $[w_{min}, w_{max}]$ |
| 192 | +4. **Student prediction**: Single-step prediction from noisy latents |
| 193 | +5. **Teacher prediction**: Multi-step DDIM prediction with CFG |
| 194 | +6. **Target prediction**: Prediction from EMA target network |
| 195 | +7. **Compute loss**: L2 or Huber loss between student and target |
| 196 | +8. **Update**: Backpropagate and update student, then EMA update target |
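
A simplified sketch of one such step is below. It is illustrative rather than the script's actual code: `guidance_scale_embedding` and `ddim_step` are hypothetical helpers, the conditioning tensors are assumed to be precomputed, and the consistency boundary scalings (c_skip/c_out) and SDXL's added conditioning kwargs are omitted for brevity.

```python
import torch
import torch.nn.functional as F

def lcm_distillation_step(latents, image_latents, text_embeds, uncond_embeds,
                          unet, teacher_unet, target_unet, noise_scheduler,
                          ddim_timesteps, w_min=3.0, w_max=15.0):
    bsz, device = latents.shape[0], latents.device

    # 1-2. Sample a timestep from the DDIM schedule (assumed ordered high -> low noise)
    #      and add noise to the latents at that timestep.
    idx = torch.randint(0, len(ddim_timesteps) - 1, (bsz,), device=device)
    t, t_prev = ddim_timesteps[idx], ddim_timesteps[idx + 1]
    noise = torch.randn_like(latents)
    noisy_latents = noise_scheduler.add_noise(latents, noise, t)

    # 3. Sample the guidance scale w ~ U[w_min, w_max] and embed it for conditioning.
    w = (w_max - w_min) * torch.rand((bsz,), device=device) + w_min
    w_embedding = guidance_scale_embedding(w, dim=256)  # hypothetical sinusoidal embedding

    # 4. Student prediction on the 8-channel input (noisy latents + original-image latents).
    unet_input = torch.cat([noisy_latents, image_latents], dim=1)
    student_pred = unet(unet_input, t, timestep_cond=w_embedding,
                        encoder_hidden_states=text_embeds).sample

    with torch.no_grad():
        # 5. Teacher prediction with classifier-free guidance, stepped back one DDIM step.
        cond = teacher_unet(unet_input, t, encoder_hidden_states=text_embeds).sample
        uncond = teacher_unet(unet_input, t, encoder_hidden_states=uncond_embeds).sample
        guided = uncond + w[:, None, None, None] * (cond - uncond)
        prev_latents = ddim_step(noisy_latents, guided, t, t_prev)  # hypothetical DDIM solver step

        # 6. Target prediction from the EMA network at the previous timestep.
        target_input = torch.cat([prev_latents, image_latents], dim=1)
        target_pred = target_unet(target_input, t_prev, timestep_cond=w_embedding,
                                  encoder_hidden_states=text_embeds).sample

    # 7. Consistency loss between student and target (see "Loss Functions" below);
    #    step 8 (backprop, optimizer step, EMA update of the target) happens in the caller.
    return F.mse_loss(student_pred.float(), target_pred.float())
```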

### Loss Functions

**L2 Loss:**
$$\mathcal{L} = \text{MSE}(\text{model\_pred}, \text{target})$$

**Huber Loss:**
$$\mathcal{L} = \sqrt{(\text{model\_pred} - \text{target})^2 + c^2} - c$$
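
A hedged sketch of how these two losses could be computed in PyTorch, matching the formulas above (`huber_c` corresponds to `--huber_c`; the Huber variant is the pseudo-Huber form):

```python
import torch
import torch.nn.functional as F

def distillation_loss(model_pred, target, loss_type="l2", huber_c=0.001):
    if loss_type == "l2":
        # Plain mean-squared error between the student and target predictions.
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    # Pseudo-Huber: quadratic near zero, roughly linear for large residuals.
    return torch.mean(
        torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c**2) - huber_c
    )
```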

## Output Structure

After training, the output directory contains:

```
output_dir/
├── unet/                 # Final student U-Net
├── unet_target/          # Final target U-Net (recommended for inference)
├── text_encoder/         # Text encoder (copied from teacher)
├── text_encoder_2/       # Second text encoder (SDXL)
├── tokenizer/            # Tokenizer
├── tokenizer_2/          # Second tokenizer
├── vae/                  # VAE
├── scheduler/            # LCM Scheduler
├── checkpoint-{step}/    # Training checkpoints
└── logs/                 # TensorBoard logs
```

## Inference

After training, use the model for fast image editing:

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline, LCMScheduler
from PIL import Image

# Load the trained model
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "./output/lcm-sdxl-instruct",
    torch_dtype=torch.float16
).to("cuda")

# Load image
image = Image.open("input.jpg")

# Edit with just 4 steps!
edited_image = pipeline(
    prompt="make it a sunset scene",
    image=image,
    num_inference_steps=4,
    guidance_scale=7.5,
    image_guidance_scale=1.5
).images[0]

edited_image.save("output.jpg")
```
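
The `unet_target/` directory holds the EMA weights, which the output structure above marks as recommended for inference. A hedged sketch of swapping them into the pipeline (paths are the same placeholders used above):

```python
import torch
from diffusers import (
    LCMScheduler,
    StableDiffusionXLInstructPix2PixPipeline,
    UNet2DConditionModel,
)

# Load the EMA target U-Net saved alongside the student.
unet = UNet2DConditionModel.from_pretrained(
    "./output/lcm-sdxl-instruct", subfolder="unet_target", torch_dtype=torch.float16
)

pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "./output/lcm-sdxl-instruct", unet=unet, torch_dtype=torch.float16
).to("cuda")

# Ensure the LCM scheduler is used for few-step sampling.
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
```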

## Tips & Best Practices

### Memory Optimization
- Use `--gradient_checkpointing` to reduce memory usage
- Enable `--enable_xformers_memory_efficient_attention` for efficiency
- Use `--mixed_precision="fp16"` or `"bf16"`
- Reduce `--train_batch_size` and increase `--gradient_accumulation_steps`

### Training Stability
- Start with `--ema_decay=0.95` for stable target updates
- Use `--loss_type="huber"` for more robust training
- Adjust `--w_min` and `--w_max` based on your dataset
- Monitor validation outputs regularly

### Quality vs Speed
- More `--num_ddim_timesteps` = better teacher guidance but slower training
- Higher `--ema_decay` = more stable but slower convergence
- Experiment with different `--learning_rate` values (1e-5 to 5e-4)

### Multi-GPU Training
Use Accelerate for distributed training:

```bash
accelerate config  # configure once
accelerate launch train_lcm_distil_instruct_pix2pix_sdxl.py [args]
```

## Troubleshooting

**NaN Loss**:
- Try `--vae_precision="fp32"`
- Reduce the learning rate
- Use gradient clipping with an appropriate `--max_grad_norm`

**Out of Memory**:
- Enable gradient checkpointing
- Reduce the batch size
- Lower the resolution
- Use xFormers attention

**Poor Quality**:
- Increase training steps
- Adjust the guidance scale range
- Check dataset quality
- Validate teacher model performance first

## Citation

If you use this code, please cite the relevant papers:

```bibtex
@article{luo2023latent,
  title={Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference},
  author={Luo, Simian and Tan, Yiqin and Huang, Longbo and Li, Jian and Zhao, Hang},
  journal={arXiv preprint arXiv:2310.04378},
  year={2023}
}

@inproceedings{brooks2023instructpix2pix,
  title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
  author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
  booktitle={CVPR},
  year={2023}
}
```

## License

Please refer to the original model licenses and dataset licenses when using this code.

## Acknowledgments

This implementation is based on:
- [Diffusers](https://github.com/huggingface/diffusers) library
- Latent Consistency Models paper
- InstructPix2Pix methodology
- Stable Diffusion XL architecture

Developed by [mzeynali01](https://medium.com/@mzeynali01).