[Flux Dreambooth lora] add latent caching #9160

A new `--cache_latents` flag is added to the argument parser:

```diff
@@ -599,6 +599,12 @@ def parse_args(input_args=None):
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Cache the VAE latents",
+    )
     parser.add_argument(
         "--report_to",
         type=str,
```
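
The flag is a plain boolean switch (`action="store_true"` with a `False` default). A minimal standalone sketch of the parsing behavior, using only stock `argparse` rather than the training script's own `parse_args`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cache_latents",
    action="store_true",  # bare on/off switch: no value follows the flag
    default=False,
    help="Cache the VAE latents",
)

# Omitting the flag leaves it False; passing it flips it to True.
print(parser.parse_args([]).cache_latents)                   # False
print(parser.parse_args(["--cache_latents"]).cache_latents)  # True
```

In practice this means caching is strictly opt-in: appending `--cache_latents` to the usual launch command is all that is needed to enable it.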

When the flag is set, the whole dataset is pushed through the VAE encoder once, before the training loop, and the per-batch latent distributions are kept in memory. The three `vae.config` values are hoisted into locals first so they stay available even after the VAE itself is deleted, which happens immediately when no `--validation_prompt` will need it later:

```diff
@@ -1456,6 +1462,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
         tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    vae_config_shift_factor = vae.config.shift_factor
+    vae_config_scaling_factor = vae.config.scaling_factor
+    vae_config_block_out_channels = vae.config.block_out_channels
+    if args.cache_latents:
+        latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(
+                    accelerator.device, non_blocking=True, dtype=weight_dtype
+                )
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+        if args.validation_prompt is None:
+            del vae
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+
     def clear_objs_and_retain_memory(objs: List[Any]):
```
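
Each cache entry is the encoder's posterior (`latent_dist`, a `DiagonalGaussianDistribution`) rather than an already-sampled tensor, so the training loop can still draw fresh latent noise every step. The sketch below shows how the cache is typically consumed later in the script; `sample_model_input` is a hypothetical helper (the script inlines equivalent logic), while `latents_cache`, the shift/scaling locals, and `weight_dtype` come from the diff above:

```python
def sample_model_input(step, batch, args, vae, latents_cache,
                       shift_factor, scaling_factor, weight_dtype):
    """Hypothetical helper mirroring the per-step latent sampling logic."""
    if args.cache_latents:
        # Sample from the cached posterior: per-step noise is preserved,
        # which caching a single .sample() tensor up front would lose.
        model_input = latents_cache[step].sample()
    else:
        # No cache: encode this batch's pixels on the fly.
        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
        model_input = vae.encode(pixel_values).latent_dist.sample()
    # Normalize with the config values hoisted out of vae.config before
    # the VAE was (possibly) deleted.
    model_input = (model_input - shift_factor) * scaling_factor
    return model_input.to(dtype=weight_dtype)
```

Note the trade-off: indexing the cache by `step` assumes the dataloader yields batches in a stable order, and the cache holds one posterior per batch in memory; in exchange, every training step skips a full VAE encoder forward pass.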