Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,7 @@ def main(args):
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16

pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
Expand All @@ -1419,7 +1420,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
images = pipeline(prompt=example["prompt"]).images

for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
Expand Down
11 changes: 8 additions & 3 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ def main(args):
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16

pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
Expand All @@ -1151,16 +1152,16 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
images = pipeline(prompt=example["prompt"]).images

for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1728,6 +1729,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
device=accelerator.device,
prompt=args.instance_prompt,
)
else:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)

# Convert images to latent space
if args.cache_latents:
Expand Down