This project is an attempt to check if it's possible to apply [ORPO](https://arxiv.org/abs/2403.07691) to a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.

> [!WARNING]
> We assume that the MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
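
To make that assumption concrete, here is a minimal sketch (not the exact training code; the function name, `beta_orpo`, and the tensor shapes are illustrative) of how the ORPO odds-ratio term can be formed from per-sample diffusion MSEs, with `-MSE` standing in for the log-probability:

```python
import torch
import torch.nn.functional as F

def diffusion_orpo_loss(pred_w, target_w, pred_l, target_l, beta_orpo=0.1):
    # Per-sample MSE for the preferred (w) and dispreferred (l) images,
    # averaged over the channel/spatial dims of (B, C, H, W) tensors.
    mse_w = F.mse_loss(pred_w.float(), target_w.float(), reduction="none").mean(dim=[1, 2, 3])
    mse_l = F.mse_loss(pred_l.float(), target_l.float(), reduction="none").mean(dim=[1, 2, 3])

    # The assumption from the warning above: -MSE stands in for the log-prob.
    log_p_w, log_p_l = -mse_w, -mse_l

    # ORPO's log odds ratio, with log(odds(p)) = log_p - log1p(-exp(log_p)).
    log_odds = (log_p_w - log_p_l) - (
        torch.log1p(-torch.exp(log_p_w)) - torch.log1p(-torch.exp(log_p_l))
    )

    # Usual diffusion loss on the preferred sample plus the odds-ratio penalty.
    loss = mse_w - beta_orpo * F.logsigmoid(log_odds)
    return loss.mean()
```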

## Training

Here's a training command you can use on a 40GB A100 to validate things on a [small preference dataset](https://hf.co/datasets/kashif/pickascore):

```bash
accelerate launch train_diffusion_orpo_sdxl_lora.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir="diffusion-sdxl-orpo" \
  --mixed_precision="fp16" \
  --dataset_name=kashif/pickascore \
  --train_batch_size=8 \
  --gradient_accumulation_steps=2 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2000 \
  --checkpointing_steps=500 \
  --run_validation --validation_steps=50 \
  --seed="0" \
  --push_to_hub
```
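
Before launching, you can sanity-check the dataset layout. This snippet just prints the splits and columns; we expect Pick-a-Pic-style fields (`caption`, `jpg_0`, `jpg_1`, `label_0`, `label_1`, matching the conversion script later in this README), but check the printed schema to confirm:

```python
from datasets import load_dataset

# Inspect the splits and column schema before training.
ds = load_dataset("kashif/pickascore")
print(ds)
```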

We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:

```bash
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
  --output_dir="diffusion-sdxl-orpo-wds" \
  --mixed_precision="fp16" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --rank=8 \
  --dataloader_num_workers=8 \
  --learning_rate=3e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=50000 \
  --checkpointing_steps=2000 \
  --run_validation --validation_steps=500 \
  --seed="0" \
  --push_to_hub
```

We tested the above on a node of 8 H100s, but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards in an S3 bucket, but it should also be possible to store them locally (see the sanity-check sketch after the conversion script below).

You can use the code below to convert the original dataset into `webdataset` shards:

```python
import os
import io
import ray
import webdataset as wds
from datasets import Dataset
from PIL import Image

ray.init(num_cpus=8)


def convert_to_image(im_bytes):
    return Image.open(io.BytesIO(im_bytes)).convert("RGB")


def main():
    dataset_path = "/pickapic_v2/data"
    wds_shards_path = "/pickapic_v2_webdataset"
    # make sure the output directory exists before writing shards
    os.makedirs(wds_shards_path, exist_ok=True)
    # get all .parquet files in the dataset path
    dataset_files = [
        os.path.join(dataset_path, f)
        for f in os.listdir(dataset_path)
        if f.endswith(".parquet")
    ]

    @ray.remote
    def create_shard(path):
        # get the shard number from the basename:
        # data-00123-of-01034.parquet -> 00123
        basename = os.path.basename(path)
        shard_num = basename.split("-")[1]
        dataset = Dataset.from_parquet(path)
        # create a webdataset shard
        shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))

        for i, example in enumerate(dataset):
            wds_example = {
                "__key__": str(i),
                "original_prompt.txt": example["caption"],
                "jpg_0.jpg": convert_to_image(example["jpg_0"]),
                "jpg_1.jpg": convert_to_image(example["jpg_1"]),
                "label_0.txt": str(example["label_0"]),
                "label_1.txt": str(example["label_1"]),
            }
            shard.write(wds_example)
        shard.close()

    futures = [create_shard.remote(path) for path in dataset_files]
    ray.get(futures)


if __name__ == "__main__":
    main()
```
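
Once the shards are written, a quick way to verify them is to stream a few samples back with `webdataset`. This sketch assumes the local output path used above and the same shard range as the S3 path earlier:

```python
import webdataset as wds

# Stream a few samples from the freshly written shards and confirm that the
# keys match what the training script expects.
dataset = wds.WebDataset("/pickapic_v2_webdataset/{00000..00644}.tar").decode("pil")
for i, sample in enumerate(dataset):
    print(sample["original_prompt.txt"], sample["label_0.txt"], sample["jpg_0.jpg"].size)
    if i == 2:
        break
```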

## Inference

Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
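
For example, the LoRA can be tried with `diffusers` along these lines (a sketch; we assume the repository hosts standard SDXL LoRA weights and that a CUDA GPU is available):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Load the SDXL base pipeline and apply the ORPO LoRA on top of it.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("sayakpaul/diffusion-sdxl-orpo")

image = pipeline("a cat playing chess, studio lighting", num_inference_steps=30).images[0]
image.save("orpo_sample.png")
```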

> [!NOTE]
> This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!