diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 951b989d7a65..a46490e8b3bf 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1399,6 +1399,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -1419,7 +1420,8 @@ def main(args): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - images = pipeline(example["prompt"]).images + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 2353625c3878..bd3a974a17d8 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1131,6 +1131,7 @@ def main(args): torch_dtype = torch.float16 elif args.prior_generation_precision == "bf16": torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, torch_dtype=torch_dtype, @@ -1151,7 +1152,8 @@ def main(args): for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - images = pipeline(example["prompt"]).images + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images for i, image in enumerate(images): hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() @@ -1159,8 +1161,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1728,6 +1729,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): device=accelerator.device, prompt=args.instance_prompt, ) + else: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) # Convert images to latent space if args.cache_latents: