From adf42c4b01d007e122f59c954435575cf70dba44 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 11:50:33 +0300 Subject: [PATCH 01/13] add latent caching + smol updates --- .../dreambooth/train_dreambooth_lora_sd3.py | 35 +++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 02f5a7ee0f7a..5ce25273cdeb 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -608,6 +608,12 @@ def parse_args(input_args=None): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) parser.add_argument( "--report_to", type=str, @@ -1394,6 +1400,16 @@ def load_model_hook(models, input_dir): logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1440,6 +1456,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers @@ -1500,7 +1519,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): power=args.lr_power, ) - # Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`. 
if args.train_text_encoder: ( @@ -1607,8 +1625,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] + if args.train_text_encoder: + models_to_accumulate.extend([text_encoder_one]) with accelerator.accumulate(models_to_accumulate): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] # encode batch prompts when custom prompts are provided for each image - @@ -1639,7 +1658,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # Convert images to latent space - model_input = vae.encode(pixel_values).latent_dist.sample() + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) @@ -1793,9 +1816,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): epoch=epoch, torch_dtype=weight_dtype, ) - - del text_encoder_one, text_encoder_two, text_encoder_three - free_memory() + if not args.train_text_encoder: + del text_encoder_one, text_encoder_two, text_encoder_three + free_memory() # Save the lora layers accelerator.wait_for_everyone() From 0c7355189e3bd06ca5aaeb8f13d97a1ee6a2c449 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 11:54:04 +0300 Subject: [PATCH 02/13] update license --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 5ce25273cdeb..2a046cd9ea8c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -140,7 +140,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++", + license="other", base_model=base_model, prompt=instance_prompt, model_description=model_description, diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index c34024f478c1..97ca44726e8f 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -119,7 +119,7 @@ def save_model_card( model_card = load_or_create_model_card( repo_id_or_path=repo_id, from_training=True, - license="openrail++", + license="other", base_model=base_model, prompt=instance_prompt, model_description=model_description, From 58e93fdfe594cdecd9193ca605f6e74efa80c905 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 11:57:12 +0300 Subject: [PATCH 03/13] replace with free_memory --- examples/dreambooth/train_dreambooth_sd3.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 97ca44726e8f..ff077b825846 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -51,7 +51,7 @@ StableDiffusion3Pipeline, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3 +from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory 
from diffusers.utils import ( check_min_version, is_wandb_available, @@ -190,8 +190,7 @@ def log_validation( ) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() return images @@ -1065,8 +1064,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1386,9 +1384,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): del tokenizers, text_encoders # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection del text_encoder_one, text_encoder_two, text_encoder_three - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't @@ -1730,8 +1726,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two, text_encoder_three - torch.cuda.empty_cache() - gc.collect() + free_memory() # Save the lora layers accelerator.wait_for_everyone() From e372e0a20bcd79b72bd5b1536bfd98de6d55e7bb Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 12:00:32 +0300 Subject: [PATCH 04/13] add --upcast_before_saving to allow saving transformer weights in lower precision --- examples/dreambooth/train_dreambooth_lora_sd3.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2a046cd9ea8c..7b418d078c08 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -634,6 +634,15 @@ def parse_args(input_args=None): " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" + "Defaults to precision dtype used for training to save memory" + ), + ) parser.add_argument( "--prior_generation_precision", type=str, @@ -1824,7 +1833,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) - transformer = transformer.to(torch.float32) + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) transformer_lora_layers = get_peft_model_state_dict(transformer) if args.train_text_encoder: From 3e57cfe9495d09aeddaf04daae05ac4ed16c0374 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 14:39:44 +0300 Subject: [PATCH 05/13] fix models to accumulate --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 7b418d078c08..5c2addb4f31b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1635,7 +1635,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): for step, batch in enumerate(train_dataloader): models_to_accumulate = [transformer] if args.train_text_encoder: - models_to_accumulate.extend([text_encoder_one]) + models_to_accumulate.extend([text_encoder_one, text_encoder_two]) with accelerator.accumulate(models_to_accumulate): prompts = batch["prompts"] From 5c07de9fe3e27bd9338b95db9f8738a96d5323f6 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 15:04:31 +0300 Subject: [PATCH 06/13] fix mixed precision issue as proposed in https://github.com/huggingface/diffusers/pull/9565 --- examples/dreambooth/train_dreambooth_lora_sd3.py | 4 +++- examples/dreambooth/train_dreambooth_sd3.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 5c2addb4f31b..b21f0aae377e 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -186,7 +186,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1805,6 +1805,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index ff077b825846..7273f99c1293 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -164,7 +164,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1704,6 +1704,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three ) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) + text_encoder_three.to(weight_dtype) pipeline = StableDiffusion3Pipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, From 0e56904445eb8f36b05b609919d6a6decdac93e1 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 15:09:42 +0300 Subject: [PATCH 07/13] smol update to readme --- examples/dreambooth/README_sd3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md index 6f41c395629a..a340be350db8 100644 --- a/examples/dreambooth/README_sd3.md +++ b/examples/dreambooth/README_sd3.md @@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \ --resolution=512 \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ - --learning_rate=1e-5 \ + --learning_rate=4e-4 \ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ From b3a7860d2d872a59a553b000f1861bffb2075d5a Mon Sep 17 00:00:00 2001 From: Linoy Date: Tue, 15 Oct 2024 12:44:28 +0000 Subject: [PATCH 08/13] style --- examples/dreambooth/train_dreambooth_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 7273f99c1293..455ba5a9293d 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -15,7 +15,6 @@ import argparse import copy -import gc import itertools import logging import math From a571ac900e72347069cef704d7f34fb4ffb7174e Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 15:47:25 +0300 Subject: [PATCH 09/13] fix caching latents --- examples/dreambooth/train_dreambooth_lora_sd3.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index b21f0aae377e..01449cbfbcee 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1512,6 +1512,20 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0) + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 6090998acc09958399687105fbeb07269635c06a Mon Sep 17 00:00:00 2001 From: Linoy Date: Tue, 15 Oct 2024 12:48:07 +0000 Subject: [PATCH 10/13] style --- examples/dreambooth/train_dreambooth_lora_sd3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 01449cbfbcee..703c64f8e399 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1525,7 +1525,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): del vae free_memory() - # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) From 5c32829212e312b472a1701faedae2d13d7c3575 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 18:12:00 +0300 Subject: [PATCH 11/13] add tests for latent caching --- .../dreambooth/test_dreambooth_lora_sd3.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py index 518738b78246..7b0da0ac4d10 100644 --- a/examples/dreambooth/test_dreambooth_lora_sd3.py +++ b/examples/dreambooth/test_dreambooth_lora_sd3.py @@ -102,7 +102,38 @@ def test_dreambooth_lora_text_encoder_sd3(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. 
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From 93f5e046913f270d799d4ecbd96f6fab916e5bb2 Mon Sep 17 00:00:00 2001 From: Linoy Date: Tue, 15 Oct 2024 15:12:37 +0000 Subject: [PATCH 12/13] style --- examples/dreambooth/test_dreambooth_lora_sd3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py index 7b0da0ac4d10..ec323be4143e 100644 --- a/examples/dreambooth/test_dreambooth_lora_sd3.py +++ b/examples/dreambooth/test_dreambooth_lora_sd3.py @@ -102,6 +102,7 @@ def test_dreambooth_lora_text_encoder_sd3(self): (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() ) self.assertTrue(starts_with_expected_prefix) + def test_dreambooth_lora_latent_caching(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -134,6 +135,7 @@ def test_dreambooth_lora_latent_caching(self): # with `"transformer"` in their names. starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" From b69f14994789089cbdd1caf031184e76f436fb63 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 15 Oct 2024 18:25:04 +0300 Subject: [PATCH 13/13] fix latent caching --- examples/dreambooth/train_dreambooth_lora_sd3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 703c64f8e399..8d0b6853eeec 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1512,6 +1512,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0) + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor if args.cache_latents: latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): @@ -1685,7 +1687,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: pixel_values = batch["pixel_values"].to(dtype=vae.dtype) model_input = vae.encode(pixel_values).latent_dist.sample() - model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents
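
Taken together, the latent-caching changes in patches 01, 09, 10, and 13 amount to the flow sketched below. This is a minimal illustration under the series' own conventions rather than the exact script: helper names such as cache_vae_latents and compute_model_input are invented here for clarity, while train_dataloader, weight_dtype, and the VAE calls mirror the hunks above.

    import torch
    from tqdm.auto import tqdm

    def cache_vae_latents(vae, train_dataloader, device, weight_dtype):
        # Encode every training image once, up front, and keep the latent
        # distributions in memory so the VAE forward pass is skipped per step.
        latents_cache = []
        for batch in tqdm(train_dataloader, desc="Caching latents"):
            with torch.no_grad():
                pixel_values = batch["pixel_values"].to(
                    device, non_blocking=True, dtype=weight_dtype
                )
                latents_cache.append(vae.encode(pixel_values).latent_dist)
        return latents_cache

    def compute_model_input(
        step, batch, vae, latents_cache, shift_factor, scaling_factor, cache_latents, weight_dtype
    ):
        # Sample from the cached distribution instead of re-encoding; indexing by
        # step assumes the dataloader yields batches in the same order as during
        # caching (no reshuffle between the caching pass and training).
        if cache_latents:
            model_input = latents_cache[step].sample()
        else:
            pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
            model_input = vae.encode(pixel_values).latent_dist.sample()
        # shift_factor and scaling_factor are read from vae.config before the VAE
        # is deleted (patch 13), so the scaling still works after free_memory().
        model_input = (model_input - shift_factor) * scaling_factor
        return model_input.to(dtype=weight_dtype)

Storing the full latent_dist rather than a pre-sampled tensor keeps the VAE encoder's per-step sampling noise while still avoiding the repeated encode, which is why the cache entries call .sample() inside the training loop.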