Merged

Changes from 4 commits

30 commits
90686c2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
7b12ed2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
17dca18
style
linoytsaban Aug 12, 2024
de24a4f
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 13, 2024
8b314e9
readme
linoytsaban Aug 13, 2024
a59b063
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
df54cd8
add test for latent caching
linoytsaban Aug 14, 2024
e0e0319
add ostris noise scheduler
linoytsaban Aug 14, 2024
18aa369
style
linoytsaban Aug 14, 2024
f97d53d
fix import
linoytsaban Aug 14, 2024
0156bec
style
linoytsaban Aug 14, 2024
c4c2c48
fix tests
linoytsaban Aug 14, 2024
d514c7b
style
linoytsaban Aug 14, 2024
7ee6041
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
d5c2a36
--change upcasting of transformer?
linoytsaban Aug 16, 2024
e760cda
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 21, 2024
f78ba77
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
1b19593
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
fbacbb5
update readme according to main
linoytsaban Sep 11, 2024
23f0636
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 11, 2024
51c7667
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 13, 2024
feae3dc
keep only latent caching
linoytsaban Sep 13, 2024
b53ae0b
add configurable param for final saving of trained layers- --upcast_b…
linoytsaban Sep 13, 2024
79e5234
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 13, 2024
5cdb4f5
style
linoytsaban Sep 13, 2024
e047ae2
Update examples/dreambooth/README_flux.md
linoytsaban Sep 14, 2024
a882c41
Update examples/dreambooth/README_flux.md
linoytsaban Sep 14, 2024
75058d7
use clear_objs_and_retain_memory from utilities
linoytsaban Sep 14, 2024
d61868e
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 14, 2024
88c0275
style
linoytsaban Sep 15, 2024
2 changes: 1 addition & 1 deletion examples/dreambooth/README_flux.md
@@ -8,7 +8,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
 >
 > Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
 > a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
-> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
+> For more tips & guidance on training on a resource-constrained device please check out these great guides and trainers for FLUX: [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md) & [`ostris` guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)


 > [!NOTE]
35 changes: 31 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_flux.py
@@ -599,6 +599,12 @@ def parse_args(input_args=None):
                 " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
             ),
         )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
@@ -1456,6 +1462,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_block_out_channels = vae.config.block_out_channels
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            del vae
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
Member
This doesn't have to be conditioned on the availability of CUDA, no?

Collaborator Author
hmm I think maybe not, but for some reason we've used this condition in most other places too

Contributor
it should be - there's one for mps to call too. i think there should be a utility helper for it

Member
Good point. Let's have it as is for now. I am working on a small utility for cleaning models and retaining accelerator memory.

Member
@linoytsaban possible to use?

def clear_objs_and_retain_memory(objs: List[Any]):

Collaborator Author
Done!
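The `clear_objs_and_retain_memory` utility referenced in the thread is not part of this diff. A minimal sketch of what such a helper might look like, assuming only the signature quoted above and covering both the CUDA and MPS cases the reviewers raise (the actual diffusers utility may differ):

import gc
from typing import Any, List

import torch


def clear_objs_and_retain_memory(objs: List[Any]):
    """Drop the given references and release cached accelerator memory."""
    for obj in objs:
        del obj  # only removes the local name; callers should also delete their own references (e.g. `del vae`)

    gc.collect()

    # Handle both backends mentioned in the review discussion above.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()

The "Done!" above refers to a later commit in this PR ("use clear_objs_and_retain_memory from utilities"), which adopts the shared helper in place of the inline `del vae` / `torch.cuda.empty_cache()` / `gc.collect()` sequence.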


     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1578,7 +1602,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             if args.train_text_encoder:
                 models_to_accumulate.extend([text_encoder_one])
             with accelerator.accumulate(models_to_accumulate):
-                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]

                 # encode batch prompts when custom prompts are provided for each image -
@@ -1610,11 +1633,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         )

                 # Convert images to latent space
-                model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
+                if args.cache_latents:
+                    model_input = latents_cache[step].sample()
+                else:
+                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))

                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
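A note on what gets cached: the caching loop stores the VAE posterior (`latent_dist`) rather than an already-sampled latent, so `latents_cache[step].sample()` still draws fresh noise on every pass while the expensive encoder forward runs only once per image, matching the stochasticity of the non-cached `vae.encode(pixel_values).latent_dist.sample()` path. A standalone sketch of that behavior (the checkpoint name and the random batch are illustrative only, not taken from the script):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; the training script loads the VAE from --pretrained_model_name_or_path.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
vae.requires_grad_(False)

pixel_values = torch.randn(1, 3, 512, 512)  # stand-in for one preprocessed training batch

with torch.no_grad():
    cached_dist = vae.encode(pixel_values).latent_dist  # encoder runs once; the distribution is cached

# Each call re-samples from the cached posterior, so every epoch sees a different latent draw.
latents_a = cached_dist.sample()
latents_b = cached_dist.sample()
print(torch.equal(latents_a, latents_b))  # False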