-
Couldn't load subscription status.
- Fork 6.5k
[Flux Dreambooth lora] add latent caching #9160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
90686c2
7b12ed2
17dca18
de24a4f
8b314e9
a59b063
df54cd8
e0e0319
18aa369
f97d53d
0156bec
c4c2c48
d514c7b
7ee6041
d5c2a36
e760cda
f78ba77
1b19593
fbacbb5
23f0636
51c7667
feae3dc
b53ae0b
79e5234
5cdb4f5
e047ae2
a882c41
75058d7
d61868e
88c0275
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,6 +103,39 @@ def test_dreambooth_lora_text_encoder_flux(self): | |
| ) | ||
| self.assertTrue(starts_with_expected_prefix) | ||
|
|
||
| def test_dreambooth_lora_latent_caching(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Love! |
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| test_args = f""" | ||
| {self.script_path} | ||
| --pretrained_model_name_or_path {self.pretrained_model_name_or_path} | ||
| --instance_data_dir {self.instance_data_dir} | ||
| --instance_prompt {self.instance_prompt} | ||
| --resolution 64 | ||
| --train_batch_size 1 | ||
| --gradient_accumulation_steps 1 | ||
| --max_train_steps 2 | ||
| --cache_latents | ||
| --learning_rate 5.0e-04 | ||
| --scale_lr | ||
| --lr_scheduler constant | ||
| --lr_warmup_steps 0 | ||
| --output_dir {tmpdir} | ||
| """.split() | ||
|
|
||
| run_command(self._launch_args + test_args) | ||
| # save_pretrained smoke test | ||
| self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) | ||
|
|
||
| # make sure the state_dict has the correct naming in the parameters. | ||
| lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) | ||
| is_lora = all("lora" in k for k in lora_state_dict.keys()) | ||
| self.assertTrue(is_lora) | ||
|
|
||
| # when not training the text encoder, all the parameters in the state dict should start | ||
| # with `"transformer"` in their names. | ||
| starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) | ||
| self.assertTrue(starts_with_transformer) | ||
|
|
||
| def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): | ||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| test_args = f""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |||
| import warnings | ||||
| from contextlib import nullcontext | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
|
|
||||
| import numpy as np | ||||
| import torch | ||||
|
|
@@ -599,6 +600,12 @@ def parse_args(input_args=None): | |||
| " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" | ||||
| ), | ||||
| ) | ||||
| parser.add_argument( | ||||
| "--cache_latents", | ||||
| action="store_true", | ||||
| default=False, | ||||
| help="Cache the VAE latents", | ||||
| ) | ||||
| parser.add_argument( | ||||
| "--report_to", | ||||
| type=str, | ||||
|
|
@@ -997,6 +1004,108 @@ def encode_prompt( | |||
| return prompt_embeds, pooled_prompt_embeds, text_ids | ||||
|
|
||||
|
|
||||
| # CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: | ||||
| # https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 | ||||
| class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): | ||||
|
||||
| def __init__(self, *args, **kwargs): | ||||
| super().__init__(*args, **kwargs) | ||||
|
|
||||
| with torch.no_grad(): | ||||
| # create weights for timesteps | ||||
| num_timesteps = 1000 | ||||
|
|
||||
| # generate the multiplier based on cosmap loss weighing | ||||
| # this is only used on linear timesteps for now | ||||
|
|
||||
| # cosine map weighing is higher in the middle and lower at the ends | ||||
| # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 | ||||
| # cosmap_weighing = 2 / (math.pi * bot) | ||||
|
|
||||
| # sigma sqrt weighing is significantly higher at the end and lower at the beginning | ||||
| sigma_sqrt_weighing = (self.sigmas**-2.0).float() | ||||
| # clip at 1e4 (1e6 is too high) | ||||
| sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) | ||||
| # bring to a mean of 1 | ||||
| sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() | ||||
|
|
||||
| # Create linear timesteps from 1000 to 0 | ||||
| timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") | ||||
|
|
||||
| self.linear_timesteps = timesteps | ||||
| # self.linear_timesteps_weights = cosmap_weighing | ||||
| self.linear_timesteps_weights = sigma_sqrt_weighing | ||||
|
|
||||
| # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu') | ||||
| pass | ||||
|
|
||||
| def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: | ||||
| # Get the indices of the timesteps | ||||
| step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] | ||||
|
|
||||
| # Get the weights for the timesteps | ||||
| weights = self.linear_timesteps_weights[step_indices].flatten() | ||||
|
|
||||
| return weights | ||||
|
|
||||
| def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: | ||||
| sigmas = self.sigmas.to(device=device, dtype=dtype) | ||||
| schedule_timesteps = self.timesteps.to(device) | ||||
| timesteps = timesteps.to(device) | ||||
| step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] | ||||
|
|
||||
| sigma = sigmas[step_indices].flatten() | ||||
| while len(sigma.shape) < n_dim: | ||||
| sigma = sigma.unsqueeze(-1) | ||||
|
|
||||
| return sigma | ||||
|
|
||||
| def add_noise( | ||||
| self, | ||||
| original_samples: torch.Tensor, | ||||
| noise: torch.Tensor, | ||||
| timesteps: torch.Tensor, | ||||
| ) -> torch.Tensor: | ||||
| ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 | ||||
| ## Add noise according to flow matching. | ||||
| ## zt = (1 - texp) * x + texp * z1 | ||||
|
|
||||
| # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) | ||||
| # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise | ||||
|
|
||||
| # timestep needs to be in [0, 1], we store them in [0, 1000] | ||||
| # noisy_sample = (1 - timestep) * latent + timestep * noise | ||||
| t_01 = (timesteps / 1000).to(original_samples.device) | ||||
| noisy_model_input = (1 - t_01) * original_samples + t_01 * noise | ||||
|
|
||||
| # n_dim = original_samples.ndim | ||||
| # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) | ||||
| # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise | ||||
| return noisy_model_input | ||||
|
|
||||
| def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: | ||||
| return sample | ||||
|
|
||||
| def set_train_timesteps(self, num_timesteps, device, linear=False): | ||||
| if linear: | ||||
| timesteps = torch.linspace(1000, 0, num_timesteps, device=device) | ||||
| self.timesteps = timesteps | ||||
| return timesteps | ||||
| else: | ||||
| # distribute them closer to center. Inference distributes them as a bias toward first | ||||
| # Generate values from 0 to 1 | ||||
| t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) | ||||
|
|
||||
| # Scale and reverse the values to go from 1000 to 0 | ||||
| timesteps = (1 - t) * 1000 | ||||
|
|
||||
| # Sort the timesteps in descending order | ||||
| timesteps, _ = torch.sort(timesteps, descending=True) | ||||
|
|
||||
| self.timesteps = timesteps.to(device=device) | ||||
|
|
||||
| return timesteps | ||||
|
|
||||
|
|
||||
| def main(args): | ||||
| if args.report_to == "wandb" and args.hub_token is not None: | ||||
| raise ValueError( | ||||
|
|
@@ -1127,7 +1236,7 @@ def main(args): | |||
| ) | ||||
|
|
||||
| # Load scheduler and models | ||||
| noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( | ||||
| noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained( | ||||
| args.pretrained_model_name_or_path, subfolder="scheduler" | ||||
| ) | ||||
| noise_scheduler_copy = copy.deepcopy(noise_scheduler) | ||||
|
|
@@ -1456,6 +1565,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): | |||
| tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) | ||||
| tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) | ||||
|
|
||||
| vae_config_shift_factor = vae.config.shift_factor | ||||
| vae_config_scaling_factor = vae.config.scaling_factor | ||||
| vae_config_block_out_channels = vae.config.block_out_channels | ||||
| if args.cache_latents: | ||||
| latents_cache = [] | ||||
| for batch in tqdm(train_dataloader, desc="Caching latents"): | ||||
| with torch.no_grad(): | ||||
| batch["pixel_values"] = batch["pixel_values"].to( | ||||
| accelerator.device, non_blocking=True, dtype=weight_dtype | ||||
| ) | ||||
| latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) | ||||
|
|
||||
| if args.validation_prompt is None: | ||||
| del vae | ||||
| if torch.cuda.is_available(): | ||||
| torch.cuda.empty_cache() | ||||
| gc.collect() | ||||
|
||||
| def clear_objs_and_retain_memory(objs: List[Any]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Uh oh!
There was an error while loading. Please reload this page.