
Commit 98ec6d8

Author: Darshil Jariwala (committed)
Commit message: merging diffusers
2 parents 7fc3daa + 8cdcdd9 commit 98ec6d8

File tree: 78 files changed (+4459, -398 lines)


.github/workflows/push_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 name: Fast GPU Tests on main
 
 on:
+  workflow_dispatch:
   push:
     branches:
       - main

docs/source/en/api/pipelines/cogvideox.md

Lines changed: 7 additions & 1 deletion
@@ -98,6 +98,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
   - all
   - __call__
 
+## CogVideoXVideoToVideoPipeline
+
+[[autodoc]] CogVideoXVideoToVideoPipeline
+  - all
+  - __call__
+
 ## CogVideoXPipelineOutput
 
-[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
+[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
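For orientation on the newly documented pipeline, here is a minimal usage sketch. The checkpoint name, prompt, strength, and step count are illustrative assumptions, and `load_video` may not exist in older diffusers versions (a plain list of PIL frames works in its place).

```python
import torch
from diffusers import CogVideoXVideoToVideoPipeline
from diffusers.utils import export_to_video, load_video

# Assumed checkpoint; any CogVideoX checkpoint compatible with the pipeline should work.
pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# `load_video` returns a list of PIL frames; passing your own list of frames also works.
input_video = load_video("input.mp4")
output = pipe(
    video=input_video,
    prompt="a golden retriever surfing a wave, cinematic",
    strength=0.8,            # how far the result may deviate from the input video
    guidance_scale=6.0,
    num_inference_steps=50,
)
export_to_video(output.frames[0], "output.mp4", fps=8)
```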

docs/source/en/api/pipelines/flux.md

Lines changed: 12 additions & 0 deletions
@@ -163,3 +163,15 @@ image.save("flux-fp8-dev.png")
 [[autodoc]] FluxPipeline
   - all
   - __call__
+
+## FluxImg2ImgPipeline
+
+[[autodoc]] FluxImg2ImgPipeline
+  - all
+  - __call__
+
+## FluxInpaintPipeline
+
+[[autodoc]] FluxInpaintPipeline
+  - all
+  - __call__
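For context on the two pipelines being documented, a minimal image-to-image sketch follows; the checkpoint, image URL, prompt, and strength are assumptions, and `FluxInpaintPipeline` follows the same pattern with an additional `mask_image` argument.

```python
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Placeholder image URL; any RGB image resized to the target resolution works.
init_image = load_image("https://example.com/cat.png").resize((1024, 1024))
image = pipe(
    prompt="a cat wizard, highly detailed, fantasy art",
    image=init_image,
    strength=0.95,
    num_inference_steps=4,   # schnell is distilled for few steps
    guidance_scale=0.0,
).images[0]
image.save("flux-img2img.png")
```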

examples/dreambooth/README_flux.md

Lines changed: 41 additions & 8 deletions
@@ -8,8 +8,10 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
 >
 > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
 > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
-> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
 
+> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
+> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
+> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)
 
 > [!NOTE]
 > **Gated model**
@@ -100,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \
   --instance_prompt="a photo of sks dog" \
   --resolution=1024 \
   --train_batch_size=1 \
+  --guidance_scale=1 \
   --gradient_accumulation_steps=4 \
-  --learning_rate=1e-4 \
+  --optimizer="prodigy" \
+  --learning_rate=1. \
   --report_to="wandb" \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
@@ -120,15 +124,23 @@ To better track our training experiments, we're using the following flags in the
 > [!NOTE]
 > If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
 
-> [!TIP]
-> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
-
 ## LoRA + DreamBooth
 
 [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
 
 Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
 
+### Prodigy Optimizer
+Prodigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence.
+By using Prodigy we can "eliminate" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
+
+To use Prodigy, specify
+```bash
+--optimizer="prodigy"
+```
+> [!TIP]
+> When using Prodigy it's generally good practice to set `--learning_rate=1.0`.
+
 To perform DreamBooth with LoRA, run:
 
 ```bash
@@ -144,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
   --train_batch_size=1 \
+  --guidance_scale=1 \
   --gradient_accumulation_steps=4 \
-  --learning_rate=1e-5 \
+  --optimizer="prodigy" \
+  --learning_rate=1. \
   --report_to="wandb" \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
@@ -162,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte
 To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
 
 > [!NOTE]
+> This is still an experimental feature.
 > FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
 By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
 > At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
@@ -180,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \
   --instance_prompt="a photo of sks dog" \
   --resolution=512 \
   --train_batch_size=1 \
+  --guidance_scale=1 \
   --gradient_accumulation_steps=4 \
-  --learning_rate=1e-5 \
+  --optimizer="prodigy" \
+  --learning_rate=1. \
   --report_to="wandb" \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
@@ -191,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \
   --push_to_hub
 ```
 
+## Memory Optimizations
+As mentioned, Flux DreamBooth LoRA training is very memory intensive. Here are some options (some still experimental) for more memory-efficient training.
+### Image Resolution
+An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all the images in the train/validation dataset are resized to it.
+Note that by default, images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions.
+### Gradient Checkpointing and Accumulation
+* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.
+By passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements.
+* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
+Instead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.
+### 8-bit-Adam Optimizer
+When training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training.
+Make sure to install `bitsandbytes` if you want to do so.
+### Latent Caching
+When training without validation runs, we can pre-encode the training images with the VAE and then delete it to free up some memory.
+To enable `latent_caching`, first use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`.
 ## Other notes
-Thanks to `bghira` for their help with reviewing & insight sharing ♥️
+Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
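The Prodigy section added to the README above boils down to swapping the optimizer class in the training loop. Below is a minimal sketch of what `--optimizer="prodigy"` amounts to, assuming the `prodigyopt` package (`pip install prodigyopt`) used by the diffusers training scripts; the toy model and hyperparameters are illustrative only.

```python
import torch
from prodigyopt import Prodigy  # assumed dependency: pip install prodigyopt

model = torch.nn.Linear(64, 64)

# Prodigy estimates its own step size from past gradients, so the nominal
# learning rate acts mostly as a multiplier -- hence the recommended --learning_rate=1.0.
optimizer = Prodigy(model.parameters(), lr=1.0, weight_decay=0.01)

for _ in range(10):
    loss = model(torch.randn(8, 64)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```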

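And a rough sketch of the latent-caching idea from the same README section: encode the training images once with the VAE, keep the latents, and free the VAE afterwards. The VAE checkpoint, dummy batches, and device handling here are assumptions, not the script's exact code; the linked PR implements the real `--cache_latents` path.

```python
import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Any AutoencoderKL demonstrates the pattern; the Flux VAE lives in a gated repo.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)

# Stand-in for the training dataloader: batches of normalized image tensors in [-1, 1].
batches = [torch.rand(1, 3, 512, 512) * 2 - 1 for _ in range(4)]

latents_cache = []
with torch.no_grad():
    for pixel_values in batches:
        pixel_values = pixel_values.to(device, dtype=vae.dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# With no validation images to decode, the VAE is no longer needed during training.
del vae
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```
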
examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 9 additions & 15 deletions
@@ -15,7 +15,6 @@
 
 import argparse
 import copy
-import gc
 import itertools
 import logging
 import math
@@ -56,6 +55,7 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
+    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
 )
@@ -210,9 +210,7 @@ def log_validation(
         }
     )
 
-    del pipeline
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clear_objs_and_retain_memory(objs=[pipeline])
 
     return images

@@ -1107,9 +1105,7 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)
 
-        del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(objs=[pipeline])
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        del tokenizers, text_encoders
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        del text_encoder_one, text_encoder_two, text_encoder_three
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clear_objs_and_retain_memory(
+            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
+        )
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     pipeline_args=pipeline_args,
                     epoch=epoch,
                 )
+                objs = []
                 if not args.train_text_encoder:
-                    del text_encoder_one, text_encoder_two, text_encoder_three
+                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
 
-                torch.cuda.empty_cache()
-                gc.collect()
+                clear_objs_and_retain_memory(objs=objs)
 
     # Save the lora layers
     accelerator.wait_for_everyone()
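The script changes above replace the repeated `del` / `gc.collect()` / `torch.cuda.empty_cache()` boilerplate with a single helper imported from `diffusers.training_utils`. Its exact implementation is not shown in this diff; the stand-in below only mirrors the pattern visible in the removed lines, so treat the name and details as assumptions.

```python
import gc
from typing import Any, List

import torch


def clear_objs_and_retain_memory_sketch(objs: List[Any]) -> None:
    """Hypothetical equivalent of diffusers.training_utils.clear_objs_and_retain_memory."""
    # Drop the references handed to us (callers also let their own bindings go out
    # of scope), then force a collection and release cached CUDA blocks.
    for obj in objs:
        del obj
    del objs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```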

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -255,8 +255,11 @@
         "BlipDiffusionPipeline",
         "CLIPImageProjection",
         "CogVideoXPipeline",
+        "CogVideoXVideoToVideoPipeline",
         "CycleDiffusionPipeline",
         "FluxControlNetPipeline",
+        "FluxImg2ImgPipeline",
+        "FluxInpaintPipeline",
         "FluxPipeline",
         "HunyuanDiTControlNetPipeline",
         "HunyuanDiTPAGPipeline",
@@ -700,8 +703,11 @@
         AuraFlowPipeline,
         CLIPImageProjection,
         CogVideoXPipeline,
+        CogVideoXVideoToVideoPipeline,
         CycleDiffusionPipeline,
         FluxControlNetPipeline,
+        FluxImg2ImgPipeline,
+        FluxInpaintPipeline,
         FluxPipeline,
         HunyuanDiTControlNetPipeline,
         HunyuanDiTPAGPipeline,

src/diffusers/image_processor.py

Lines changed: 1 addition & 1 deletion
@@ -569,7 +569,7 @@ def preprocess(
             channel = image.shape[1]
             # don't need any preprocess if the image is latents
-            if channel == self.vae_latent_channels:
+            if channel == self.config.vae_latent_channels:
                 return image
 
             height, width = self.get_default_height_width(image, height, width)

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 4 additions & 2 deletions
@@ -562,7 +562,8 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
             new_key += ".attn.to_out.0"
         elif "processor.proj_lora2" in old_key:
             new_key += ".attn.to_add_out"
-        elif "processor.qkv_lora1" in old_key and "up" not in old_key:
+        # Handle text latents.
+        elif "processor.qkv_lora2" in old_key and "up" not in old_key:
             handle_qkv(
                 old_state_dict,
                 new_state_dict,
@@ -574,7 +575,8 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
                 ],
             )
             # continue
-        elif "processor.qkv_lora2" in old_key and "up" not in old_key:
+        # Handle image latents.
+        elif "processor.qkv_lora1" in old_key and "up" not in old_key:
             handle_qkv(
                 old_state_dict,
                 new_state_dict,

src/diffusers/models/attention.py

Lines changed: 20 additions & 2 deletions
@@ -1104,8 +1104,26 @@ def forward(
             accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
             num_times_accumulated[:, frame_start:frame_end] += weights
 
-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        # hidden_states = torch.where(
+        #     num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # )
+        #
+        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
         ).to(dtype)
 
         # 3. Feed-forward
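The new inline comment frames this change as a memory trade-off rather than a numerical one. The standalone sketch below reproduces the pattern on dummy tensors to show that the chunked version matches the one-shot `torch.where` while only materializing `context_length`-sized temporaries; the shapes and chunk size are arbitrary assumptions.

```python
import torch

# Dummy stand-ins for the accumulators in the free-noise attention block.
batch, frames, channels, context_length = 1, 64, 8, 16
accumulated_values = torch.randn(batch, frames, channels)
num_times_accumulated = torch.randint(0, 3, (batch, frames, 1)).float()

# One-shot version: torch.where over the full tensors allocates full-size
# temporaries (the mask and the divided tensor), which spikes memory for long videos.
reference = torch.where(
    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
)

# Chunked version: split along the frame dimension, normalize each chunk, and
# concatenate, so every intermediate stays context_length frames wide.
chunked = torch.cat(
    [
        torch.where(counts > 0, values / counts, values)
        for values, counts in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)

assert torch.allclose(reference, chunked)
```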
