Skip to content

Commit df0ccd8

Browse files
committed
[fix] fix for prior preservation and mixed precision sampling
1 parent 425a715 commit df0ccd8

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,7 @@ def main(args):
11931193
subfolder="transformer",
11941194
revision=args.revision,
11951195
variant=args.variant,
1196+
torch_dtype=torch_dtype,
11961197
)
11971198
pipeline = FluxKontextPipeline.from_pretrained(
11981199
args.pretrained_model_name_or_path,
@@ -1215,7 +1216,8 @@ def main(args):
12151216
for example in tqdm(
12161217
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
12171218
):
1218-
images = pipeline(example["prompt"]).images
1219+
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
1220+
images = pipeline(prompt=example["prompt"]).images
12191221

12201222
for i, image in enumerate(images):
12211223
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1789,6 +1791,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17891791
device=accelerator.device,
17901792
prompt=args.instance_prompt,
17911793
)
1794+
else:
1795+
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
1796+
prompts, text_encoders, tokenizers
1797+
)
17921798

17931799
# Convert images to latent space
17941800
if args.cache_latents:

0 commit comments

Comments
 (0)