Commit d457bee

Update README.md to update the MaPO project (#8470)
1 parent 1d9a6a8 commit d457bee

File tree

1 file changed (+1, -121 lines)

  • examples/research_projects/diffusion_orpo

@@ -1,121 +1 @@
This project is an attempt to check if it's possible to apply [ORPO](https://arxiv.org/abs/2403.07691) to a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.

> [!WARNING]
> We assume that the MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
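
To make that assumption concrete, here is a minimal sketch of the resulting loss, treating `-MSE` as a surrogate log-probability in ORPO's odds-ratio term. The function and tensor names are illustrative, not the training script's API:

```python
import torch
import torch.nn.functional as F


def orpo_diffusion_loss(model_pred_w, target_w, model_pred_l, target_l, beta=0.1):
    # Per-sample MSE over (C, H, W) for the preferred ("w") and rejected ("l")
    # images, assuming 4D (B, C, H, W) latents.
    mse_w = F.mse_loss(model_pred_w.float(), target_w.float(), reduction="none").mean(dim=[1, 2, 3])
    mse_l = F.mse_loss(model_pred_l.float(), target_l.float(), reduction="none").mean(dim=[1, 2, 3])

    # Surrogate log-probs: log p ≈ -MSE (so p = exp(-MSE) lies in (0, 1]).
    log_p_w, log_p_l = -mse_w, -mse_l

    # ORPO log odds ratio: [log p_w - log(1 - p_w)] - [log p_l - log(1 - p_l)],
    # with log(1 - p) computed stably as log1p(-exp(log p)).
    log_odds = (log_p_w - log_p_l) - (
        torch.log1p(-torch.exp(log_p_w)) - torch.log1p(-torch.exp(log_p_l))
    )

    # Supervised term on the preferred sample plus the weighted odds-ratio term.
    return mse_w.mean() - beta * F.logsigmoid(log_odds).mean()
```

With `beta=0` this reduces to the ordinary diffusion loss on the preferred image, which is a handy sanity check.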

## Training

Here's a training command you can use on a 40GB A100 to validate things on a [small preference dataset](https://hf.co/datasets/kashif/pickascore):
```bash
accelerate launch train_diffusion_orpo_sdxl_lora.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-orpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub
```
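
If you'd like to peek at the preference data before training, you can load it with `datasets` and inspect the schema. We expect Pick-a-Pic-style columns (`caption`, `jpg_0`, `jpg_1`, `label_0`, `label_1`), but verify against the printed features:

```python
from datasets import load_dataset

# Load the small preference dataset used in the command above and print its
# splits and features.
ds = load_dataset("kashif/pickascore")
print(ds)
```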

We also provide a simple script to scale up training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:

```bash
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
  --output_dir="diffusion-sdxl-orpo-wds" \
  --mixed_precision="fp16" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --dataloader_num_workers=8 \
  --learning_rate=3e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=50000 \
  --checkpointing_steps=2000 \
  --run_validation --validation_steps=500 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub
```

We tested the above on a node of 8 H100s, but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards in an S3 bucket, but it should also be possible to store them locally.

You can use the code below to convert the original dataset into `webdataset` shards:
```python
import io
import os

import ray
import webdataset as wds
from datasets import Dataset
from PIL import Image

ray.init(num_cpus=8)


def convert_to_image(im_bytes):
    return Image.open(io.BytesIO(im_bytes)).convert("RGB")


def main():
    dataset_path = "/pickapic_v2/data"
    wds_shards_path = "/pickapic_v2_webdataset"
    # Make sure the output directory exists before the workers write to it.
    os.makedirs(wds_shards_path, exist_ok=True)
    # Get all .parquet files in the dataset path.
    dataset_files = [
        os.path.join(dataset_path, f)
        for f in os.listdir(dataset_path)
        if f.endswith(".parquet")
    ]

    @ray.remote
    def create_shard(path):
        # Get the shard number from the file's basename:
        # data-00123-of-01034.parquet -> 00123.
        basename = os.path.basename(path)
        shard_num = basename.split("-")[1]
        dataset = Dataset.from_parquet(path)
        # Create one webdataset shard per parquet file.
        shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))

        for i, example in enumerate(dataset):
            wds_example = {
                "__key__": str(i),
                "original_prompt.txt": example["caption"],
                "jpg_0.jpg": convert_to_image(example["jpg_0"]),
                "jpg_1.jpg": convert_to_image(example["jpg_1"]),
                "label_0.txt": str(example["label_0"]),
                "label_1.txt": str(example["label_1"]),
            }
            shard.write(wds_example)
        shard.close()

    futures = [create_shard.remote(path) for path in dataset_files]
    ray.get(futures)


if __name__ == "__main__":
    main()
```
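
Once the shards are written, a quick way to sanity-check them is to stream a couple of samples back with `webdataset`. This is a sketch: the local path is the `wds_shards_path` from the conversion script above, and for S3-hosted shards you would substitute the `pipe:aws s3 cp ...` spec used in the training command:

```python
import webdataset as wds

# Stream samples from the local shards; brace expansion covers the shard range.
url = "pipe:cat /pickapic_v2_webdataset/{00000..00644}.tar"
dataset = wds.WebDataset(url).decode("pil")

for i, sample in enumerate(dataset):
    # Keys follow the fields written by the conversion script.
    print(sample["original_prompt.txt"], sample["jpg_0.jpg"].size, sample["label_0.txt"])
    if i == 2:
        break
```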

## Inference

Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
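
As a sketch of how you might try it (assuming the checkpoint is a standard SDXL LoRA; the prompt is arbitrary):

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Load the experimental ORPO LoRA on top of the base model.
pipeline.load_lora_weights("sayakpaul/diffusion-sdxl-orpo")

image = pipeline("a cat playing chess, studio lighting", num_inference_steps=30).images[0]
image.save("orpo_sample.png")
```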
This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!
