Merged

Changes from 17 commits

Commits (30):
90686c2: add ostris trainer to README & add cache latents of vae (linoytsaban, Aug 12, 2024)
7b12ed2: add ostris trainer to README & add cache latents of vae (linoytsaban, Aug 12, 2024)
17dca18: style (linoytsaban, Aug 12, 2024)
de24a4f: Merge branch 'main' into dreambooth-lora (linoytsaban, Aug 13, 2024)
8b314e9: readme (linoytsaban, Aug 13, 2024)
a59b063: Merge branch 'main' into dreambooth-lora (linoytsaban, Aug 14, 2024)
df54cd8: add test for latent caching (linoytsaban, Aug 14, 2024)
e0e0319: add ostris noise scheduler (linoytsaban, Aug 14, 2024)
18aa369: style (linoytsaban, Aug 14, 2024)
f97d53d: fix import (linoytsaban, Aug 14, 2024)
0156bec: style (linoytsaban, Aug 14, 2024)
c4c2c48: fix tests (linoytsaban, Aug 14, 2024)
d514c7b: style (linoytsaban, Aug 14, 2024)
7ee6041: Merge branch 'main' into dreambooth-lora (linoytsaban, Aug 14, 2024)
d5c2a36: --change upcasting of transformer? (linoytsaban, Aug 16, 2024)
e760cda: Merge branch 'main' into dreambooth-lora (linoytsaban, Aug 21, 2024)
f78ba77: Merge branch 'main' into dreambooth-lora (sayakpaul, Aug 22, 2024)
1b19593: Merge branch 'main' into dreambooth-lora (sayakpaul, Aug 22, 2024)
fbacbb5: update readme according to main (linoytsaban, Sep 11, 2024)
23f0636: Merge branch 'main' into dreambooth-lora (linoytsaban, Sep 11, 2024)
51c7667: Merge branch 'main' into dreambooth-lora (linoytsaban, Sep 13, 2024)
feae3dc: keep only latent caching (linoytsaban, Sep 13, 2024)
b53ae0b: add configurable param for final saving of trained layers- --upcast_b… (linoytsaban, Sep 13, 2024)
79e5234: Merge branch 'main' into dreambooth-lora (linoytsaban, Sep 13, 2024)
5cdb4f5: style (linoytsaban, Sep 13, 2024)
e047ae2: Update examples/dreambooth/README_flux.md (linoytsaban, Sep 14, 2024)
a882c41: Update examples/dreambooth/README_flux.md (linoytsaban, Sep 14, 2024)
75058d7: use clear_objs_and_retain_memory from utilities (linoytsaban, Sep 14, 2024)
d61868e: Merge branch 'main' into dreambooth-lora (linoytsaban, Sep 14, 2024)
88c0275: style (linoytsaban, Sep 15, 2024)
5 changes: 4 additions & 1 deletion examples/dreambooth/README_flux.md
@@ -8,7 +8,10 @@ The `train_dreambooth_flux.py` script shows how to implement the training procedure
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)

> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)


> [!NOTE]
33 changes: 33 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_flux.py
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self):
)
self.assertTrue(starts_with_expected_prefix)

def test_dreambooth_lora_latent_caching(self):
Review comment (Member): Love!

with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)

# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
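
For reference, the naming checks in this new test can also be reproduced by hand on a finished training run. The sketch below is not part of the test file; `output_dir` is a hypothetical output directory produced by `train_dreambooth_lora_flux.py`.

```python
# Hedged sketch: inspect a saved Flux LoRA the same way the test above does.
# `output_dir` is a hypothetical path, not something defined in this PR.
import os

import safetensors.torch

output_dir = "trained-flux-lora"  # hypothetical
state_dict = safetensors.torch.load_file(os.path.join(output_dir, "pytorch_lora_weights.safetensors"))

# every parameter should belong to a LoRA layer ...
assert all("lora" in key for key in state_dict)
# ... and, with the text encoder frozen, should live under the Flux transformer
assert all(key.startswith("transformer") for key in state_dict)
```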

def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
142 changes: 136 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -25,6 +25,7 @@
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Union

import numpy as np
import torch
@@ -599,6 +600,12 @@ def parse_args(input_args=None):
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--cache_latents",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--report_to",
type=str,
@@ -997,6 +1004,108 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids


# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
Review comment (Member): @linoytsaban I would just re-use the relevant parts from the original CustomFlowMatchEulerDiscreteScheduler here rather than copy-pasting it entirely. For example, we don't need the get_sigmas() method because we already have one inside the script. Also, the option to use this scheduler and the related changes should be made configurable, IMO.

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000

# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now

# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)

# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()

# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")

self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing

# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
pass

def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()

return weights

def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)

return sigma

def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1

# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise

# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input

def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample

def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

# Scale and reverse the values to go from 1000 to 0
timesteps = (1 - t) * 1000

# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)

self.timesteps = timesteps.to(device=device)

return timesteps
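
For intuition, below is a minimal, self-contained sketch of the "sigma sqrt" weighting (sigma**-2, clipped and normalized to mean 1) that `__init__` precomputes and `get_weights_for_timesteps` looks up. The linear sigma schedule here is a hypothetical stand-in, not the scheduler's real sigmas, and the snippet is illustrative rather than part of the training script.

```python
# Illustrative sketch of the per-timestep loss weighting used by
# CustomFlowMatchEulerDiscreteScheduler (hypothetical linear sigmas, not the real schedule).
import torch

num_timesteps = 1000
sigmas = torch.linspace(1.0, 1e-3, num_timesteps)   # stand-in for self.sigmas

weights = torch.clamp(sigmas**-2.0, max=1e4)        # sigma**-2, clipped at 1e4
weights = weights / weights.mean()                  # bring to a mean of 1

# during training, each sampled timestep index picks up its weight, which can
# then scale the per-sample loss before averaging
timestep_indices = torch.randint(0, num_timesteps, (4,))
per_sample_weights = weights[timestep_indices]      # shape: (4,)
# loss = (per_sample_weights * per_sample_mse).mean()
```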


def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -1127,7 +1236,7 @@ def main(args):
)

# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
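
Following the review comment above about making the scheduler swap configurable, one possible shape for that is sketched below. The `--use_custom_flowmatch_scheduler` flag is hypothetical and not part of this diff; the snippet assumes the script's existing imports (`copy`, `FlowMatchEulerDiscreteScheduler`) and the class defined above.

```python
# Hypothetical --use_custom_flowmatch_scheduler flag (not in this PR) gating the swap
# between the stock scheduler and the ostris-style one defined above.
scheduler_cls = (
    CustomFlowMatchEulerDiscreteScheduler
    if args.use_custom_flowmatch_scheduler
    else FlowMatchEulerDiscreteScheduler
)
noise_scheduler = scheduler_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
```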
@@ -1456,6 +1565,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
if args.cache_latents:
latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

if args.validation_prompt is None:
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
Review thread on the memory cleanup above:
Member: This doesn't have to be conditioned on the availability of CUDA, no?
Collaborator (author): Hmm, I think maybe not, but for some reason we've used this condition in most other places too.
Contributor: It should be; there's one for mps to call too. I think there should be a utility helper for it.
Member: Good point. Let's have it as is for now. I am working on a small utility for cleaning models and retaining accelerator memory.
Member: @linoytsaban possible to use? `def clear_objs_and_retain_memory(objs: List[Any]):`
Collaborator (author): Done!


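To summarize the `--cache_latents` path added in this file: every training image is encoded once up front, only the latent distributions are kept, and the training loop samples from the cache instead of re-running the VAE each step. Below is a condensed sketch of that pattern; `vae` and `train_dataloader` are placeholders rather than the script's actual objects.

```python
# Condensed sketch of the latent-caching pattern implemented above
# (placeholder `vae` / `train_dataloader`; mirrors, but is not, the script's code).
import torch


@torch.no_grad()
def cache_vae_latents(vae, train_dataloader, device, dtype):
    """Encode each training batch once and keep only the latent distributions."""
    latents_cache = []
    for batch in train_dataloader:
        pixel_values = batch["pixel_values"].to(device, non_blocking=True, dtype=dtype)
        latents_cache.append(vae.encode(pixel_values).latent_dist)
    return latents_cache


# inside the training loop, step `step` then reuses the cached distribution:
#   model_input = latents_cache[step].sample()
#   model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
```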
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1578,7 +1705,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.train_text_encoder:
models_to_accumulate.extend([text_encoder_one])
with accelerator.accumulate(models_to_accumulate):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]

# encode batch prompts when custom prompts are provided for each image -
@@ -1610,11 +1736,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)

vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))

latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
@@ -1793,7 +1923,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(torch.float32)
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)

if args.train_text_encoder:
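
Relatedly, commit b53ae0b mentions a configurable parameter for upcasting the trained layers before the final save, while this 17-commit view simply changes the cast to `weight_dtype`. A hedged sketch of how such an option could look is below; the `--upcast_before_saving` flag name is hypothetical and not taken from this diff.

```python
# Hypothetical --upcast_before_saving flag (name not taken from this diff): cast the trained
# transformer to fp32 before extracting the LoRA state dict, otherwise keep weight_dtype.
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
    transformer = transformer.to(torch.float32)
else:
    transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
```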