diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2097fd398f20..69bd39944a2d 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -178,11 +178,11 @@ def log_validation( else: logger.warning(f"image logging not implemented for {tracker.name}") - del pipeline - gc.collect() - torch.cuda.empty_cache() + del pipeline + gc.collect() + torch.cuda.empty_cache() - return image_logs + return image_logs def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 232d3da8e820..94f030fe01ef 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -192,9 +192,9 @@ def log_validation( else: logger.warning(f"image logging not implemented for {tracker.name}") - del pipeline - free_memory() - return image_logs + del pipeline + free_memory() + return image_logs def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 488c80e67d59..ecd7572ca39f 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -199,13 +199,13 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v else: logger.warning(f"image logging not implemented for {tracker.name}") - del pipeline - free_memory() + del pipeline + free_memory() - if not is_final_validation: - controlnet.to(accelerator.device) + if not is_final_validation: + controlnet.to(accelerator.device) - return image_logs + return image_logs # Copied from dreambooth sd3 example diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 3368db1ec096..76d232da1c4e 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -201,11 +201,11 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, else: logger.warning(f"image logging not implemented for {tracker.name}") - del pipeline - gc.collect() - torch.cuda.empty_cache() + del pipeline + gc.collect() + torch.cuda.empty_cache() - return image_logs + return image_logs def import_model_class_from_model_name_or_path(