Skip to content

Commit 1f1ce2e

Browse files
authored
Merge branch 'main' into WanVACE_Attention
2 parents fb531ec + 006d092 commit 1f1ce2e

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1399,6 +1399,7 @@ def main(args):
13991399
torch_dtype = torch.float16
14001400
elif args.prior_generation_precision == "bf16":
14011401
torch_dtype = torch.bfloat16
1402+
14021403
pipeline = FluxPipeline.from_pretrained(
14031404
args.pretrained_model_name_or_path,
14041405
torch_dtype=torch_dtype,
@@ -1419,7 +1420,8 @@ def main(args):
14191420
for example in tqdm(
14201421
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
14211422
):
1422-
images = pipeline(example["prompt"]).images
1423+
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
1424+
images = pipeline(prompt=example["prompt"]).images
14231425

14241426
for i, image in enumerate(images):
14251427
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def main(args):
11311131
torch_dtype = torch.float16
11321132
elif args.prior_generation_precision == "bf16":
11331133
torch_dtype = torch.bfloat16
1134+
11341135
pipeline = FluxPipeline.from_pretrained(
11351136
args.pretrained_model_name_or_path,
11361137
torch_dtype=torch_dtype,
@@ -1151,16 +1152,16 @@ def main(args):
11511152
for example in tqdm(
11521153
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
11531154
):
1154-
images = pipeline(example["prompt"]).images
1155+
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
1156+
images = pipeline(prompt=example["prompt"]).images
11551157

11561158
for i, image in enumerate(images):
11571159
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
11581160
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
11591161
image.save(image_filename)
11601162

11611163
del pipeline
1162-
if torch.cuda.is_available():
1163-
torch.cuda.empty_cache()
1164+
free_memory()
11641165

11651166
# Handle the repository creation
11661167
if accelerator.is_main_process:
@@ -1728,6 +1729,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17281729
device=accelerator.device,
17291730
prompt=args.instance_prompt,
17301731
)
1732+
else:
1733+
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
1734+
prompts, text_encoders, tokenizers
1735+
)
17311736

17321737
# Convert images to latent space
17331738
if args.cache_latents:

0 commit comments

Comments
 (0)