Merged
Changes from all 30 commits
90686c2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
7b12ed2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
17dca18
style
linoytsaban Aug 12, 2024
de24a4f
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 13, 2024
8b314e9
readme
linoytsaban Aug 13, 2024
a59b063
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
df54cd8
add test for latent caching
linoytsaban Aug 14, 2024
e0e0319
add ostris noise scheduler
linoytsaban Aug 14, 2024
18aa369
style
linoytsaban Aug 14, 2024
f97d53d
fix import
linoytsaban Aug 14, 2024
0156bec
style
linoytsaban Aug 14, 2024
c4c2c48
fix tests
linoytsaban Aug 14, 2024
d514c7b
style
linoytsaban Aug 14, 2024
7ee6041
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
d5c2a36
--change upcasting of transformer?
linoytsaban Aug 16, 2024
e760cda
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 21, 2024
f78ba77
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
1b19593
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
fbacbb5
update readme according to main
linoytsaban Sep 11, 2024
23f0636
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 11, 2024
51c7667
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 13, 2024
feae3dc
keep only latent caching
linoytsaban Sep 13, 2024
b53ae0b
add configurable param for final saving of trained layers- --upcast_b…
linoytsaban Sep 13, 2024
79e5234
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 13, 2024
5cdb4f5
style
linoytsaban Sep 13, 2024
e047ae2
Update examples/dreambooth/README_flux.md
linoytsaban Sep 14, 2024
a882c41
Update examples/dreambooth/README_flux.md
linoytsaban Sep 14, 2024
75058d7
use clear_objs_and_retain_memory from utilities
linoytsaban Sep 14, 2024
d61868e
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 14, 2024
88c0275
style
linoytsaban Sep 15, 2024
8 changes: 6 additions & 2 deletions examples/dreambooth/README_flux.md
@@ -221,8 +221,12 @@ Instead, only a subset of these activations (the checkpoints) are stored and the
### 8-bit-Adam Optimizer
When training with `AdamW` (this does not apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
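For illustration (not part of this diff), passing `--use_8bit_adam` swaps the optimizer for the 8-bit AdamW variant from `bitsandbytes`, roughly as in this sketch; the toy parameter list is an assumption standing in for the trainable LoRA parameters:

```python
import bitsandbytes as bnb
import torch

# Toy parameter list standing in for the trainable LoRA parameters.
params_to_optimize = [torch.nn.Parameter(torch.zeros(16, 16))]

# --use_8bit_adam switches to the 8-bit AdamW implementation, which keeps
# optimizer state in 8 bits and noticeably reduces memory usage.
optimizer = bnb.optim.AdamW8bit(
    params_to_optimize,
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)
```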
### latent caching
### Latent caching
When training without validation runs, we can pre-encode the training images with the VAE and then delete it to free up some memory.
to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`
To enable `latent_caching`, simply pass `--cache_latents`.
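For intuition, the caching pass boils down to the following self-contained sketch; the randomly initialized `AutoencoderKL` and the random-tensor dataloader are stand-ins for the FLUX VAE and the real training dataloader:

```python
import torch
from diffusers import AutoencoderKL
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float32

# Toy stand-ins: a randomly initialized VAE and a dataloader of random "images".
vae = AutoencoderKL().to(device, weight_dtype).eval()
images = torch.rand(4, 3, 64, 64)
train_dataloader = DataLoader(TensorDataset(images), batch_size=2)

# One pass over the data: encode every batch and keep only the latent
# distributions, so the VAE itself can be deleted afterwards.
latents_cache = []
for (pixel_values,) in train_dataloader:
    with torch.no_grad():
        pixel_values = pixel_values.to(device, dtype=weight_dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)

# Without validation runs, the VAE is no longer needed during training.
del vae
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Later, at each training step, a latent is drawn from the cached distribution.
model_input = latents_cache[0].sample()
```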
### Precision of saved LoRA layers
By default, the trained transformer layers are saved in the precision dtype used for training. For example, when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers are saved in `torch.bfloat16` as well.
This significantly reduces memory requirements without a significant quality loss. Note that if you wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
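A minimal sketch of the dtype decision this flag controls, with a toy `torch.nn.Linear` standing in for the LoRA-augmented transformer (the actual script extracts the LoRA state dict afterwards with `get_peft_model_state_dict`):

```python
import torch

# Toy module standing in for the LoRA-augmented transformer.
transformer = torch.nn.Linear(8, 8).to(torch.bfloat16)
weight_dtype = torch.bfloat16
upcast_before_saving = True  # i.e. --upcast_before_saving was passed

# Upcast to float32 for full-precision weights, otherwise keep the
# (smaller) training dtype.
if upcast_before_saving:
    transformer.to(torch.float32)
else:
    transformer.to(weight_dtype)

print(next(transformer.parameters()).dtype)  # torch.float32
```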

## Other notes
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
33 changes: 33 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_flux.py
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self):
)
self.assertTrue(starts_with_expected_prefix)

def test_dreambooth_lora_latent_caching(self):
Member (review comment): Love!
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)

def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
59 changes: 44 additions & 15 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -15,7 +15,6 @@

import argparse
import copy
import gc
import itertools
import logging
import math
@@ -56,6 +55,7 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
@@ -600,6 +600,12 @@ def parse_args(input_args=None):
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
@@ -620,6 +626,15 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
@@ -1422,12 +1437,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1457,6 +1467,21 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
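# The config values above are stashed because the VAE itself may be deleted below once the latents are cached.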
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
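# When there are no validation runs, the VAE is only needed for this caching pass and can be released right away.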

if args.validation_prompt is None:
clear_objs_and_retain_memory([vae])

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1579,7 +1604,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder_one])
with accelerator.accumulate(models_to_accumulate):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]

# encode batch prompts when custom prompts are provided for each image -
@@ -1613,11 +1637,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
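# Sample from the pre-computed latent distribution when caching is enabled; otherwise encode the batch with the VAE on the fly.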
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)

vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))

latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
@@ -1789,15 +1817,16 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
torch_dtype=weight_dtype,
)
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two
torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])

# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(torch.float32)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

if args.train_text_encoder: