Commit acdd37b

Merge branch 'main' into controlnet_num_train_epochs_patch
2 parents 224471b + 0028c34

2 files changed: +18 −138 lines
Lines changed: 1 addition & 121 deletions
@@ -1,121 +1 @@
-This project is an attempt to check if it's possible to apply [ORPO](https://arxiv.org/abs/2403.07691) to a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.
-
-> [!WARNING]
-> We assume that the MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
-
-## Training
-
-Here's a training command you can use on a 40GB A100 to validate things on a [small preference
-dataset](https://hf.co/datasets/kashif/pickascore):
-
-```bash
-accelerate launch train_diffusion_orpo_sdxl_lora.py \
-  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
-  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
-  --output_dir="diffusion-sdxl-orpo" \
-  --mixed_precision="fp16" \
-  --dataset_name=kashif/pickascore \
-  --train_batch_size=8 \
-  --gradient_accumulation_steps=2 \
-  --gradient_checkpointing \
-  --use_8bit_adam \
-  --rank=8 \
-  --learning_rate=1e-5 \
-  --report_to="wandb" \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
-  --max_train_steps=2000 \
-  --checkpointing_steps=500 \
-  --run_validation --validation_steps=50 \
-  --seed="0" \
-  --report_to="wandb" \
-  --push_to_hub
-```
-
-We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:
-
-```bash
-accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
-  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
-  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
-  --dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
-  --output_dir="diffusion-sdxl-orpo-wds" \
-  --mixed_precision="fp16" \
-  --gradient_accumulation_steps=1 \
-  --gradient_checkpointing \
-  --use_8bit_adam \
-  --rank=8 \
-  --dataloader_num_workers=8 \
-  --learning_rate=3e-5 \
-  --report_to="wandb" \
-  --lr_scheduler="constant" \
-  --lr_warmup_steps=0 \
-  --max_train_steps=50000 \
-  --checkpointing_steps=2000 \
-  --run_validation --validation_steps=500 \
-  --seed="0" \
-  --report_to="wandb" \
-  --push_to_hub
-```
-
-We tested the above on a node of 8 H100s, but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards in an S3 bucket, but it should also be possible to store them locally.
-
-You can use the code below to convert the original dataset into `webdataset` shards:
-
-```python
-import os
-import io
-import ray
-import webdataset as wds
-from datasets import Dataset
-from PIL import Image
-
-ray.init(num_cpus=8)
-
-
-def convert_to_image(im_bytes):
-    return Image.open(io.BytesIO(im_bytes)).convert("RGB")
-
-def main():
-    dataset_path = "/pickapic_v2/data"
-    wds_shards_path = "/pickapic_v2_webdataset"
-    # get all .parquet files in the dataset path
-    dataset_files = [
-        os.path.join(dataset_path, f)
-        for f in os.listdir(dataset_path)
-        if f.endswith(".parquet")
-    ]
-
-    @ray.remote
-    def create_shard(path):
-        # get basename of the file
-        basename = os.path.basename(path)
-        # get the shard number: data-00123-of-01034.parquet -> 00123
-        shard_num = basename.split("-")[1]
-        dataset = Dataset.from_parquet(path)
-        # create a webdataset shard
-        shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))
-
-        for i, example in enumerate(dataset):
-            wds_example = {
-                "__key__": str(i),
-                "original_prompt.txt": example["caption"],
-                "jpg_0.jpg": convert_to_image(example["jpg_0"]),
-                "jpg_1.jpg": convert_to_image(example["jpg_1"]),
-                "label_0.txt": str(example["label_0"]),
-                "label_1.txt": str(example["label_1"])
-            }
-            shard.write(wds_example)
-        shard.close()
-
-    futures = [create_shard.remote(path) for path in dataset_files]
-    ray.get(futures)
-
-
-if __name__ == "__main__":
-    main()
-```
-
-## Inference
-
-Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
+This project has a new home now: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We released our paper too!
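For context on the warning in the removed README above: ORPO needs per-example log-probabilities to form its odds ratio, and the project approximates them with the diffusion MSE. The snippet below is only a rough, hypothetical sketch of that assumption — the function name, the `beta_orpo` weight, and the 4D latent shapes are all made up here, and `-MSE` is used as a stand-in for `log p`. See the training scripts referenced in the README for the actual implementation.

```python
import torch
import torch.nn.functional as F


def diffusion_orpo_loss(pred_w, pred_l, target_w, target_l, beta_orpo=0.1):
    """Hypothetical ORPO-style objective where -MSE stands in for log p of each branch."""
    # Per-sample MSE of the preferred ("w") and rejected ("l") noise predictions,
    # averaged over the (C, H, W) dims of the latents.
    mse_w = F.mse_loss(pred_w.float(), target_w.float(), reduction="none").mean(dim=[1, 2, 3])
    mse_l = F.mse_loss(pred_l.float(), target_l.float(), reduction="none").mean(dim=[1, 2, 3])

    # Treat -MSE as log p, so log-odds = log p - log(1 - p); assumes MSE > 0 so exp(-MSE) < 1.
    log_p_w, log_p_l = -mse_w, -mse_l
    log_odds = (log_p_w - log_p_l) - (
        torch.log1p(-torch.exp(log_p_w)) - torch.log1p(-torch.exp(log_p_l))
    )

    # Usual diffusion loss on the preferred branch plus the weighted odds-ratio penalty.
    ratio_loss = -F.logsigmoid(log_odds).mean()
    return mse_w.mean() + beta_orpo * ratio_loss
```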

src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py

Lines changed: 17 additions & 17 deletions
@@ -376,6 +376,7 @@ def __call__(
 
 # 2. Define call parameters
 batch_size = 1 if isinstance(prompt, str) else len(prompt)
+device = self._execution_device
 
 if editing_prompt:
 enable_edit_guidance = True
@@ -405,7 +406,7 @@ def __call__(
 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
 )
 text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
 # duplicate text embeddings for each generation per prompt, using mps friendly method
 bs_embed, seq_len, _ = text_embeddings.shape
@@ -433,9 +434,9 @@ def __call__(
 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
 )
 edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
-edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
+edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
 else:
-edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
+edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)
 
 # duplicate text embeddings for each generation per prompt, using mps friendly method
 bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
@@ -476,7 +477,7 @@ def __call__(
 truncation=True,
 return_tensors="pt",
 )
-uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
 
 # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
 seq_len = uncond_embeddings.shape[1]
@@ -493,7 +494,7 @@ def __call__(
 # get the initial random noise unless the user supplied it
 
 # 4. Prepare timesteps
-self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+self.scheduler.set_timesteps(num_inference_steps, device=device)
 timesteps = self.scheduler.timesteps
 
 # 5. Prepare latent variables
@@ -504,7 +505,7 @@ def __call__(
 height,
 width,
 text_embeddings.dtype,
-self.device,
+device,
 generator,
 latents,
 )
@@ -562,12 +563,12 @@ def __call__(
 if enable_edit_guidance:
 concept_weights = torch.zeros(
 (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
-device=self.device,
+device=device,
 dtype=noise_guidance.dtype,
 )
 noise_guidance_edit = torch.zeros(
 (len(noise_pred_edit_concepts), *noise_guidance.shape),
-device=self.device,
+device=device,
 dtype=noise_guidance.dtype,
 )
 # noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -644,21 +645,19 @@ def __call__(
 
 # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
 
-warmup_inds = torch.tensor(warmup_inds).to(self.device)
+warmup_inds = torch.tensor(warmup_inds).to(device)
 if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
 concept_weights = concept_weights.to("cpu") # Offload to cpu
 noise_guidance_edit = noise_guidance_edit.to("cpu")
 
-concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
+concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
 concept_weights_tmp = torch.where(
 concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
 )
 concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
 # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
 
-noise_guidance_edit_tmp = torch.index_select(
-noise_guidance_edit.to(self.device), 0, warmup_inds
-)
+noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
 noise_guidance_edit_tmp = torch.einsum(
 "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
 )
@@ -669,8 +668,8 @@ def __call__(
 
 del noise_guidance_edit_tmp
 del concept_weights_tmp
-concept_weights = concept_weights.to(self.device)
-noise_guidance_edit = noise_guidance_edit.to(self.device)
+concept_weights = concept_weights.to(device)
+noise_guidance_edit = noise_guidance_edit.to(device)
 
 concept_weights = torch.where(
 concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
@@ -679,6 +678,7 @@ def __call__(
 concept_weights = torch.nan_to_num(concept_weights)
 
 noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
+noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
 
 noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
 
@@ -689,7 +689,7 @@ def __call__(
 self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
 
 if sem_guidance is not None:
-edit_guidance = sem_guidance[i].to(self.device)
+edit_guidance = sem_guidance[i].to(device)
 noise_guidance = noise_guidance + edit_guidance
 
 noise_pred = noise_pred_uncond + noise_guidance
@@ -705,7 +705,7 @@ def __call__(
 # 8. Post-processing
 if not output_type == "latent":
 image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 else:
 image = latents
 has_nsfw_concept = None
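The pipeline diff above applies one pattern throughout: resolve the device once via `self._execution_device` (which, in diffusers, takes model-offloading hooks into account) and pass that local variable around, instead of repeatedly reading `self.device`, which can point at `cpu` or `meta` while components are offloaded. Below is a minimal, self-contained stand-in illustrating the idea; the `TinyPipeline` class is hypothetical and not diffusers code.

```python
import torch


class TinyPipeline:
    """Hypothetical stand-in for a diffusers pipeline, showing the device-resolution pattern."""

    def __init__(self, module: torch.nn.Module):
        self.module = module

    @property
    def _execution_device(self) -> torch.device:
        # Simplification: real pipelines also inspect offloading hooks here;
        # we just report where the module's parameters currently live.
        return next(self.module.parameters()).device

    @torch.no_grad()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Resolve the device once at the top of the call and reuse the local
        # variable, rather than reading a possibly stale `self.device` everywhere.
        device = self._execution_device
        return self.module(x.to(device))


pipe = TinyPipeline(torch.nn.Linear(4, 4))
print(pipe(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```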

0 commit comments