Skip to content

Commit 0ff1aa9

Browse files
Brvcketlinoytsabansayakpaul
authored
[fix] fix for prior preservation and mixed precision sampling (huggingface#11873)
Co-authored-by: Linoy Tsaban <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 901da9d commit 0ff1aa9

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ def main(args):
12701270
subfolder="transformer",
12711271
revision=args.revision,
12721272
variant=args.variant,
1273+
torch_dtype=torch_dtype,
12731274
)
12741275
pipeline = FluxKontextPipeline.from_pretrained(
12751276
args.pretrained_model_name_or_path,
@@ -1292,7 +1293,8 @@ def main(args):
12921293
for example in tqdm(
12931294
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
12941295
):
1295-
images = pipeline(example["prompt"]).images
1296+
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
1297+
images = pipeline(prompt=example["prompt"]).images
12961298

12971299
for i, image in enumerate(images):
12981300
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1899,6 +1901,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18991901
device=accelerator.device,
19001902
prompt=args.instance_prompt,
19011903
)
1904+
else:
1905+
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
1906+
prompts, text_encoders, tokenizers
1907+
)
19021908

19031909
# Convert images to latent space
19041910
if args.cache_latents:

0 commit comments

Comments
 (0)