
Commit ccfaf0b
initial script copied from the dpo trainer
1 parent 7fb481f

File tree: 3 files changed, +1201 −0 lines changed
Lines changed: 30 additions & 0 deletions
# Diffusion Model Alignment Using GRPO

This directory provides LoRA implementations of Diffusion [GRPO](https://arxiv.org/abs/2402.03300), an RL-based alignment method that is a variant of Proximal Policy Optimization (PPO), applied in the diffusion model setting.

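Unlike PPO, GRPO does not train a separate value model: it samples a group of images per prompt, scores them with a reward model, and normalizes the rewards within the group to obtain advantages, which are then used in a PPO-style clipped objective over the denoising steps. The snippet below is a minimal sketch of that objective only, with assumed tensor names, shapes, and clipping value; it is not the code from this example's training script.

```python
import torch

def grpo_loss(log_probs, old_log_probs, rewards, clip_range=0.2):
    """Minimal sketch of a GRPO-style clipped objective.

    log_probs, old_log_probs: (group_size,) log-probs of the sampled
        denoising trajectories under the current and old policies.
    rewards: (group_size,) scalar rewards, e.g. from a preference model.
    clip_range: placeholder clipping threshold.
    """
    # Group-relative advantage: normalize rewards within the group
    # instead of subtracting a learned value baseline as in PPO.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # PPO-style clipped surrogate on the importance ratio.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()
```
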
## SDXL training command

```bash
accelerate launch train_diffusion_grpo_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-dpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --push_to_hub
```
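
After training finishes, the LoRA adapters written to `--output_dir` (and, with `--push_to_hub`, uploaded to the Hub) can be attached to the base SDXL pipeline for inference. The snippet below is a minimal sketch that assumes the script saves the adapters in a diffusers-compatible LoRA format; the prompt and output path are placeholders.

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Load the base SDXL model and attach the trained LoRA adapters.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("diffusion-sdxl-dpo")  # the --output_dir used above

image = pipe(prompt="a photo of an astronaut riding a horse").images[0]
image.save("sample.png")
```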
Lines changed: 8 additions & 0 deletions
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
peft
wandb
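
Assuming this file is saved as `requirements.txt` alongside the training script, the dependencies can be installed with `pip install -r requirements.txt` before running the command above.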
