diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 137f3222f6d9..6787c37f93a8 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -38,10 +38,7 @@
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -726,7 +723,8 @@ def log_validation(
                 }
             )
 
-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()
 
     return videos
 
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index e344a9b1e2a5..5969218f3c3e 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -54,7 +54,7 @@
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
+from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
@@ -193,7 +193,8 @@ def log_validation(
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")
 
-        clear_objs_and_retain_memory([pipeline])
+        del pipeline
+        free_memory()
 
         return image_logs
 
@@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
         )
 
-        clear_objs_and_retain_memory([text_encoders, tokenizers])
+        del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        free_memory()
 
     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index 4b255c501d99..9ea78370f5e0 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -49,11 +49,7 @@
     StableDiffusion3ControlNetPipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (
-    clear_objs_and_retain_memory,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
-)
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")
 
-    clear_objs_and_retain_memory(pipeline)
+    del pipeline
+    free_memory()
 
     if not is_final_validation:
         controlnet.to(accelerator.device)
@@ -1131,7 +1128,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
         new_fingerprint = Hasher.hash(args)
         train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
 
-    clear_objs_and_retain_memory(text_encoders + tokenizers)
+    del text_encoder_one, text_encoder_two, text_encoder_three
+    del tokenizer_one, tokenizer_two, tokenizer_three
+    free_memory()
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 6091622719ee..fcc11386abcf 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
 
         if args.validation_prompt is None:
-            clear_objs_and_retain_memory([vae])
+            del vae
+            free_memory()
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
-                    clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+                    del text_encoder_one, text_encoder_two
+                    free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 3060813bbbdc..02f5a7ee0f7a 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -211,7 +211,8 @@ def log_validation(
                 }
             )
 
-    clear_objs_and_retain_memory(objs=[pipeline])
+    del pipeline
+    free_memory()
 
     return images
 
@@ -1106,7 +1107,8 @@ def main(args):
                     image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                     image.save(image_filename)
 
-            clear_objs_and_retain_memory(objs=[pipeline])
+            del pipeline
+            free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1453,9 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        clear_objs_and_retain_memory(
-            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
-        )
+        del tokenizers, text_encoders
+        del text_encoder_one, text_encoder_two, text_encoder_three
+        free_memory()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1791,11 +1793,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                objs = []
-                if not args.train_text_encoder:
-                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
 
-                clear_objs_and_retain_memory(objs=objs)
+                del text_encoder_one, text_encoder_two, text_encoder_three
+                free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 26d4a2a504c6..57bd9074870c 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting
 
 
-def clear_objs_and_retain_memory(objs: List[Any]):
-    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
-    if len(objs) >= 1:
-        for obj in objs:
-            del obj
-
+def free_memory():
+    """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()
 
     if torch.cuda.is_available():
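The caller-side pattern that every script in this diff converges on is the same: delete the large objects in the scope that owns them, then call the new free_memory() utility. Below is a minimal sketch of that pattern; run_validation and build_pipeline are hypothetical names used only for illustration, and the only API taken from the diff is diffusers.training_utils.free_memory.

    from diffusers.training_utils import free_memory  # replaces clear_objs_and_retain_memory


    def run_validation(build_pipeline, prompts):
        """Hypothetical helper mirroring the scripts' log_validation flow."""
        pipe = build_pipeline()  # create the large object in the scope that will delete it
        outputs = [pipe(p) for p in prompts]

        # New pattern from the diff: drop the owning reference, then collect.
        # The old clear_objs_and_retain_memory([pipe]) only deleted its own
        # local loop variable, so the caller's references kept the pipeline alive.
        del pipe
        free_memory()  # gc.collect(), then clear the accelerator cache

        return outputs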