
Commit feb721e

added training for lcm distil instruct-pix2pix-sdxl
1 parent db32cc7 commit feb721e

File tree: 2 files changed, +1848 −0 lines changed

Lines changed: 327 additions & 0 deletions
@@ -0,0 +1,327 @@

# LCM Distillation for InstructPix2Pix SDXL

This repository contains a training script for distilling Latent Consistency Models (LCM) for InstructPix2Pix with Stable Diffusion XL (SDXL). The script enables fast, few-step image editing by distilling knowledge from a teacher model into a student model.

## Overview

This implementation performs **LCM distillation** on InstructPix2Pix SDXL models, which allows for:

- **Fast image editing** with significantly fewer sampling steps (1-4 steps vs 50+ steps)
- **Instruction-based image manipulation** using text prompts
- **High-quality outputs** that maintain the teacher model's capabilities

The training uses a teacher-student distillation approach where:

- **Teacher**: Pre-trained InstructPix2Pix SDXL model (8-channel input U-Net)
- **Student**: Lightweight model with time conditioning that learns to match teacher outputs
- **Target**: EMA (Exponential Moving Average) version of the student, used for stable training

## Requirements

```bash
pip install torch torchvision
pip install diffusers transformers accelerate
pip install datasets pillow
pip install tensorboard   # or wandb for logging
pip install xformers      # optional, for memory efficiency
pip install bitsandbytes  # optional, for 8-bit Adam optimizer
```

## Dataset Format

The script expects datasets with three components per sample:

1. **Original Image**: The input image to be edited
2. **Edit Prompt**: Text instruction describing the desired edit
3. **Edited Image**: The target output after applying the edit
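
For orientation, here is a minimal sketch of what one such sample could look like when loaded with the `datasets` library. The column names and placeholder images below are illustrative only; point `--original_image_column`, `--edit_prompt_column`, and `--edited_image_column` at whatever names your dataset actually uses.

```python
from datasets import Dataset
from PIL import Image

# Two hypothetical rows: original image, edit instruction, edited image.
rows = {
    "original_image": [Image.new("RGB", (512, 512), "gray")] * 2,
    "edit_prompt": ["make it a sunset scene", "turn the car red"],
    "edited_image": [Image.new("RGB", (512, 512), "orange")] * 2,
}
dataset = Dataset.from_dict(rows)
print(dataset)                      # shows the three columns and num_rows=2
print(dataset[0]["edit_prompt"])    # "make it a sunset scene"
```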

### Supported Formats

**Option 1: HuggingFace Dataset**

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --dataset_name="your/dataset-name" \
  --dataset_config_name="default"
```

**Option 2: Local ImageFolder**

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --train_data_dir="./data/train"
```

## Key Arguments

### Model Configuration

- `--pretrained_teacher_model`: Path/ID of the teacher InstructPix2Pix SDXL model
- `--pretrained_vae_model_name_or_path`: Optional separate VAE model path
- `--vae_precision`: VAE precision (`fp16`, `fp32`, `bf16`)
- `--unet_time_cond_proj_dim`: Time conditioning projection dimension (default: 256)

### Dataset Arguments

- `--dataset_name`: HuggingFace dataset name
- `--train_data_dir`: Local training data directory
- `--original_image_column`: Column name for original images
- `--edit_prompt_column`: Column name for edit prompts
- `--edited_image_column`: Column name for edited images
- `--max_train_samples`: Limit number of training samples

### Training Configuration

- `--resolution`: Image resolution (default: 512)
- `--train_batch_size`: Batch size per device
- `--num_train_epochs`: Number of training epochs
- `--max_train_steps`: Maximum training steps
- `--gradient_accumulation_steps`: Gradient accumulation steps
- `--learning_rate`: Learning rate (default: 1e-4)
- `--lr_scheduler`: Learning rate scheduler type
- `--lr_warmup_steps`: Number of warmup steps

### LCM-Specific Arguments

- `--w_min`: Minimum guidance scale for sampling (default: 3.0)
- `--w_max`: Maximum guidance scale for sampling (default: 15.0); see the sampling sketch after this list
- `--num_ddim_timesteps`: Number of DDIM timesteps (default: 50)
- `--loss_type`: Loss function type (`l2` or `huber`)
- `--huber_c`: Huber loss parameter (default: 0.001)
- `--ema_decay`: EMA decay rate for target network (default: 0.95)
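
During training, a guidance scale `w` is drawn uniformly from `[w_min, w_max]` for every sample and handed to the student U-Net as an extra conditioning signal through the `--unet_time_cond_proj_dim` projection. The snippet below sketches that mechanism using the sinusoidal w-embedding from the LCM reference implementations; the function name and exact details are assumptions, not this script's verified code.

```python
import torch

def guidance_scale_embedding(w, embedding_dim=256):
    # Sinusoidal embedding of the guidance scale w (same construction as timestep embeddings);
    # the result feeds the student U-Net's `time_cond_proj_dim` projection.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half_dim) / (half_dim - 1))
    emb = w[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (batch, embedding_dim)

# Per-sample guidance scales drawn uniformly from [w_min, w_max]:
w_min, w_max = 3.0, 15.0
w = torch.rand(4) * (w_max - w_min) + w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=256)
print(w_embedding.shape)  # torch.Size([4, 256])
```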

### Optimization

- `--use_8bit_adam`: Use 8-bit Adam optimizer
- `--adam_beta1`, `--adam_beta2`: Adam optimizer betas
- `--adam_weight_decay`: Weight decay
- `--adam_epsilon`: Adam epsilon
- `--max_grad_norm`: Maximum gradient norm for clipping
- `--mixed_precision`: Mixed precision training (`no`, `fp16`, `bf16`)
- `--gradient_checkpointing`: Enable gradient checkpointing
- `--enable_xformers_memory_efficient_attention`: Use xFormers
- `--allow_tf32`: Allow TF32 on Ampere GPUs

### Validation

- `--val_image_url_or_path`: Validation image path/URL
- `--validation_prompt`: Validation edit prompt
- `--num_validation_images`: Number of validation images to generate
- `--validation_steps`: Validate every N steps

### Logging & Checkpointing

- `--output_dir`: Output directory for checkpoints
- `--logging_dir`: TensorBoard logging directory
- `--report_to`: Reporting integration (`tensorboard`, `wandb`)
- `--checkpointing_steps`: Save checkpoint every N steps
- `--checkpoints_total_limit`: Maximum number of checkpoints to keep
- `--resume_from_checkpoint`: Resume from checkpoint path

### Hub Integration

- `--push_to_hub`: Push model to HuggingFace Hub
- `--hub_token`: HuggingFace Hub token
- `--hub_model_id`: Hub model ID

## Training Example

### Basic Training

```bash
python train_lcm_distil_instruct_pix2pix_sdxl.py \
  --pretrained_teacher_model="diffusers/sdxl-instructpix2pix" \
  --dataset_name="your/instruct-pix2pix-dataset" \
  --output_dir="./output/lcm-sdxl-instruct" \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --max_train_steps=10000 \
  --validation_steps=500 \
  --checkpointing_steps=500 \
  --mixed_precision="fp16" \
  --seed=42
```

### Advanced Training with Optimizations

```bash
accelerate launch --multi_gpu train_lcm_distil_instruct_pix2pix_sdxl.py \
  --pretrained_teacher_model="diffusers/sdxl-instructpix2pix" \
  --dataset_name="your/instruct-pix2pix-dataset" \
  --output_dir="./output/lcm-sdxl-instruct" \
  --resolution=768 \
  --train_batch_size=2 \
  --gradient_accumulation_steps=8 \
  --learning_rate=5e-5 \
  --max_train_steps=20000 \
  --num_ddim_timesteps=50 \
  --w_min=3.0 \
  --w_max=15.0 \
  --ema_decay=0.95 \
  --loss_type="huber" \
  --huber_c=0.001 \
  --mixed_precision="bf16" \
  --gradient_checkpointing \
  --enable_xformers_memory_efficient_attention \
  --use_8bit_adam \
  --validation_steps=250 \
  --val_image_url_or_path="path/to/val_image.jpg" \
  --validation_prompt="make it sunset" \
  --num_validation_images=4 \
  --checkpointing_steps=500 \
  --checkpoints_total_limit=3 \
  --push_to_hub \
  --hub_model_id="your-username/lcm-sdxl-instruct" \
  --report_to="wandb"
```

## How It Works

### Architecture

1. **Teacher U-Net**: Pre-trained 8-channel InstructPix2Pix SDXL U-Net
   - Input: concatenated noisy latent + original image latent (8 channels; see the sketch after this list)
   - Performs multi-step diffusion with classifier-free guidance

2. **Student U-Net**: Distilled model with time conditioning
   - Learns to predict in a single step what the teacher predicts in multiple steps
   - Uses a guidance scale embedding for conditioning

3. **Target U-Net**: EMA copy of the student
   - Provides stable training targets
   - Updated with an exponential moving average
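
As a concrete illustration of the 8-channel input: the noisy latents of the edited image and the VAE-encoded latents of the original image are concatenated along the channel dimension before every U-Net call. The tensor shapes below are hypothetical examples, not values taken from the script.

```python
import torch

noisy_latents = torch.randn(1, 4, 96, 96)           # noised latents of the edited image
original_image_latents = torch.randn(1, 4, 96, 96)  # VAE latents of the original (conditioning) image
unet_input = torch.cat([noisy_latents, original_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([1, 8, 96, 96])
```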

### Training Process

The training loop implements the LCM distillation algorithm:

1. **Sample timestep** from DDIM schedule
2. **Add noise** to latents at sampled timestep
3. **Sample guidance scale** $w$ from uniform distribution $[w_{min}, w_{max}]$
4. **Student prediction**: Single-step prediction from noisy latents
5. **Teacher prediction**: Multi-step DDIM prediction with CFG
6. **Target prediction**: Prediction from EMA target network
7. **Compute loss**: L2 or Huber loss between student and target
8. **Update**: Backpropagate and update student, then EMA update target
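
The following is a simplified, self-contained sketch of one such step. Toy zero-output "U-Nets" stand in for the real SDXL models, text conditioning and the LCM boundary-condition scalings (c_skip/c_out) are omitted, only a single teacher DDIM step is shown, and every name (`student`, `teacher`, `target`, `pred_x0`, ...) is illustrative rather than an identifier from this script.

```python
import torch

# Scaled-linear noise schedule as used by Stable Diffusion models.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

num_ddim_timesteps, step = 50, 1000 // 50            # spacing between consecutive DDIM timesteps
w_min, w_max, huber_c, ema_decay = 3.0, 15.0, 0.001, 0.95

def toy_unet(x, t, w=None):                           # epsilon-prediction stand-in: (latents, t, w) -> noise
    return torch.zeros_like(x)[:, :4]

student = teacher = target = toy_unet

def pred_x0(x_t, eps, t):
    # Convert an epsilon prediction at timestep t into an x0 prediction.
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (x_t - (1 - a).sqrt() * eps) / a.sqrt()

latents = torch.randn(2, 4, 64, 64)                   # VAE latents of the edited image
original_latents = torch.randn(2, 4, 64, 64)          # VAE latents of the original image
noise = torch.randn_like(latents)

# 1-2) sample a DDIM timestep and add noise at that timestep
t = torch.randint(1, num_ddim_timesteps, (2,)) * step
prev_t = t - step
a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
noisy = a_t.sqrt() * latents + (1 - a_t).sqrt() * noise

# 3) sample a guidance scale per example
w = torch.rand(2) * (w_max - w_min) + w_min

# 4) student: single-step x0 prediction from the 8-channel input at t
student_x0 = pred_x0(noisy, student(torch.cat([noisy, original_latents], 1), t, w), t)

with torch.no_grad():
    # 5) teacher: classifier-free-guided noise prediction, then one DDIM step t -> prev_t
    eps_cond = teacher(torch.cat([noisy, original_latents], 1), t)
    eps_uncond = teacher(torch.cat([noisy, torch.zeros_like(original_latents)], 1), t)
    eps = eps_uncond + w.view(-1, 1, 1, 1) * (eps_cond - eps_uncond)
    x0 = pred_x0(noisy, eps, t)
    a_prev = alphas_cumprod[prev_t].view(-1, 1, 1, 1)
    prev_noisy = a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps

    # 6) target: the EMA network's x0 prediction one DDIM step earlier
    target_x0 = pred_x0(prev_noisy, target(torch.cat([prev_noisy, original_latents], 1), prev_t, w), prev_t)

# 7) pseudo-Huber loss between the student and target predictions
loss = (((student_x0 - target_x0) ** 2 + huber_c ** 2).sqrt() - huber_c).mean()

# 8) after loss.backward() and optimizer.step(), the target net is updated as an EMA of the student:
#    target_param.data = ema_decay * target_param.data + (1 - ema_decay) * student_param.data
```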

### Loss Functions

**L2 Loss:**

$$\mathcal{L} = \text{MSE}(\text{model\_pred}, \text{target})$$

**Huber Loss:**

$$\mathcal{L} = \sqrt{(\text{model\_pred} - \text{target})^2 + c^2} - c$$
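
In PyTorch terms, these two objectives could be computed along the following lines (a sketch; the function name and the `float()` casts are illustrative, with `c` corresponding to `--huber_c`):

```python
import torch
import torch.nn.functional as F

def distillation_loss(model_pred, target, loss_type="l2", huber_c=0.001):
    # L2 is plain mean-squared error; the pseudo-Huber variant above behaves like L2
    # near zero and like L1 for large residuals, making it more robust to outliers.
    if loss_type == "l2":
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    diff = model_pred.float() - target.float()
    return torch.mean(torch.sqrt(diff ** 2 + huber_c ** 2) - huber_c)
```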

## Output Structure

After training, the output directory contains:

```
output_dir/
├── unet/               # Final student U-Net
├── unet_target/        # Final target U-Net (recommended for inference)
├── text_encoder/       # Text encoder (copied from teacher)
├── text_encoder_2/     # Second text encoder (SDXL)
├── tokenizer/          # Tokenizer
├── tokenizer_2/        # Second tokenizer
├── vae/                # VAE
├── scheduler/          # LCM Scheduler
├── checkpoint-{step}/  # Training checkpoints
└── logs/               # TensorBoard logs
```
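
Since the tree marks `unet_target/` as the U-Net recommended for inference, one way to load it explicitly is sketched below using standard diffusers APIs; the paths and the exact save layout are assumptions, not verified against the script.

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline, UNet2DConditionModel, LCMScheduler

# Load the EMA target U-Net from its subfolder and plug it into the pipeline.
unet = UNet2DConditionModel.from_pretrained(
    "./output/lcm-sdxl-instruct", subfolder="unet_target", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "./output/lcm-sdxl-instruct", unet=unet, torch_dtype=torch.float16
)
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")
```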

## Inference

After training, use the model for fast image editing:

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline, LCMScheduler
from PIL import Image

# Load the trained model
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "./output/lcm-sdxl-instruct",
    torch_dtype=torch.float16
).to("cuda")

# Load image
image = Image.open("input.jpg")

# Edit with just 4 steps!
edited_image = pipeline(
    prompt="make it a sunset scene",
    image=image,
    num_inference_steps=4,
    guidance_scale=7.5,
    image_guidance_scale=1.5
).images[0]

edited_image.save("output.jpg")
```

## Tips & Best Practices

### Memory Optimization

- Use `--gradient_checkpointing` to reduce memory usage
- Enable `--enable_xformers_memory_efficient_attention` for efficiency
- Use `--mixed_precision="fp16"` or `"bf16"`
- Reduce `--train_batch_size` and increase `--gradient_accumulation_steps`

### Training Stability

- Start with `--ema_decay=0.95` for stable target updates
- Use `--loss_type="huber"` for more robust training
- Adjust `--w_min` and `--w_max` based on your dataset
- Monitor validation outputs regularly

### Quality vs Speed

- A higher `--num_ddim_timesteps` gives finer teacher guidance but slows training
- A higher `--ema_decay` gives more stable targets but slower convergence
- Experiment with different `--learning_rate` values (1e-5 to 5e-4)

### Multi-GPU Training

Use Accelerate for distributed training:

```bash
accelerate config   # Configure once
accelerate launch train_lcm_distil_instruct_pix2pix_sdxl.py [args]
```

## Troubleshooting

**NaN Loss**:
- Try `--vae_precision="fp32"`
- Reduce learning rate
- Use gradient clipping with appropriate `--max_grad_norm`

**Out of Memory**:
- Enable gradient checkpointing
- Reduce batch size
- Lower resolution
- Use xFormers attention

**Poor Quality**:
- Increase training steps
- Adjust guidance scale range
- Check dataset quality
- Validate teacher model performance first

## Citation

If you use this code, please cite the relevant papers:

```bibtex
@article{luo2023latent,
  title={Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference},
  author={Luo, Simian and Tan, Yiqin and Huang, Longbo and Li, Jian and Zhao, Hang},
  journal={arXiv preprint arXiv:2310.04378},
  year={2023}
}

@inproceedings{brooks2023instructpix2pix,
  title={InstructPix2Pix: Learning to Follow Image Editing Instructions},
  author={Brooks, Tim and Holynski, Aleksander and Efros, Alexei A},
  booktitle={CVPR},
  year={2023}
}
```

## License

Please refer to the original model licenses and dataset licenses when using this code.

## Acknowledgments

This implementation is based on:

- The [Diffusers](https://github.com/huggingface/diffusers) library
- The Latent Consistency Models paper
- The InstructPix2Pix methodology
- The Stable Diffusion XL architecture

Developed by [mzeynali01](https://medium.com/@mzeynali01).
