|
1 | | -This project is an attempt to check if it's possible to apply to [ORPO](https://arxiv.org/abs/2403.07691) on a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/. |
2 | | - |
3 | | -> [!WARNING] |
4 | | -> We assume that MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental. |
5 | | -
|
6 | | -## Training |
7 | | - |
8 | | -Here's training command you can use on a 40GB A100 to validate things on a [small preference |
9 | | -dataset](https://hf.co/datasets/kashif/pickascore): |
10 | | - |
11 | | -```bash |
12 | | -accelerate launch train_diffusion_orpo_sdxl_lora.py \ |
13 | | - --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \ |
14 | | - --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \ |
15 | | - --output_dir="diffusion-sdxl-orpo" \ |
16 | | - --mixed_precision="fp16" \ |
17 | | - --dataset_name=kashif/pickascore \ |
18 | | - --train_batch_size=8 \ |
19 | | - --gradient_accumulation_steps=2 \ |
20 | | - --gradient_checkpointing \ |
21 | | - --use_8bit_adam \ |
22 | | - --rank=8 \ |
23 | | - --learning_rate=1e-5 \ |
24 | | - --report_to="wandb" \ |
25 | | - --lr_scheduler="constant" \ |
26 | | - --lr_warmup_steps=0 \ |
27 | | - --max_train_steps=2000 \ |
28 | | - --checkpointing_steps=500 \ |
29 | | - --run_validation --validation_steps=50 \ |
30 | | - --seed="0" \ |
31 | | - --report_to="wandb" \ |
32 | | - --push_to_hub |
33 | | -``` |
34 | | - |
35 | | -We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset: |
36 | | - |
37 | | -```bash |
38 | | -accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \ |
39 | | - --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \ |
40 | | - --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \ |
41 | | - --dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \ |
42 | | - --output_dir="diffusion-sdxl-orpo-wds" \ |
43 | | - --mixed_precision="fp16" \ |
44 | | - --gradient_accumulation_steps=1 \ |
45 | | - --gradient_checkpointing \ |
46 | | - --use_8bit_adam \ |
47 | | - --rank=8 \ |
48 | | - --dataloader_num_workers=8 \ |
49 | | - --learning_rate=3e-5 \ |
50 | | - --report_to="wandb" \ |
51 | | - --lr_scheduler="constant" \ |
52 | | - --lr_warmup_steps=0 \ |
53 | | - --max_train_steps=50000 \ |
54 | | - --checkpointing_steps=2000 \ |
55 | | - --run_validation --validation_steps=500 \ |
56 | | - --seed="0" \ |
57 | | - --report_to="wandb" \ |
58 | | - --push_to_hub |
59 | | -``` |
60 | | - |
61 | | -We tested the above on a node of 8 H100s but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards on an S3 bucket but it should be also possible to have them stored locally. |
62 | | - |
63 | | -You can use the code below to convert the original dataset into `webdataset` shards: |
64 | | - |
65 | | -```python |
66 | | -import os |
67 | | -import io |
68 | | -import ray |
69 | | -import webdataset as wds |
70 | | -from datasets import Dataset |
71 | | -from PIL import Image |
72 | | - |
73 | | -ray.init(num_cpus=8) |
74 | | - |
75 | | - |
76 | | -def convert_to_image(im_bytes): |
77 | | - return Image.open(io.BytesIO(im_bytes)).convert("RGB") |
78 | | - |
79 | | -def main(): |
80 | | - dataset_path = "/pickapic_v2/data" |
81 | | - wds_shards_path = "/pickapic_v2_webdataset" |
82 | | - # get all .parquet files in the dataset path |
83 | | - dataset_files = [ |
84 | | - os.path.join(dataset_path, f) |
85 | | - for f in os.listdir(dataset_path) |
86 | | - if f.endswith(".parquet") |
87 | | - ] |
88 | | - |
89 | | - @ray.remote |
90 | | - def create_shard(path): |
91 | | - # get basename of the file |
92 | | - basename = os.path.basename(path) |
93 | | - # get the shard number data-00123-of-01034.parquet -> 00123 |
94 | | - shard_num = basename.split("-")[1] |
95 | | - dataset = Dataset.from_parquet(path) |
96 | | - # create a webdataset shard |
97 | | - shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar")) |
98 | | - |
99 | | - for i, example in enumerate(dataset): |
100 | | - wds_example = { |
101 | | - "__key__": str(i), |
102 | | - "original_prompt.txt": example["caption"], |
103 | | - "jpg_0.jpg": convert_to_image(example["jpg_0"]), |
104 | | - "jpg_1.jpg": convert_to_image(example["jpg_1"]), |
105 | | - "label_0.txt": str(example["label_0"]), |
106 | | - "label_1.txt": str(example["label_1"]) |
107 | | - } |
108 | | - shard.write(wds_example) |
109 | | - shard.close() |
110 | | - |
111 | | - futures = [create_shard.remote(path) for path in dataset_files] |
112 | | - ray.get(futures) |
113 | | - |
114 | | - |
115 | | -if __name__ == "__main__": |
116 | | - main() |
117 | | -``` |
118 | | - |
119 | | -## Inference |
120 | | - |
121 | | -Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint. |
| 1 | +This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too! |
0 commit comments