Skip to content

Commit 69d61e5

Browse files
committed
Fix previous bug, and add support for passing pre-encoded prompt embeddings (`prompt_embeds`) as a list of torch Tensors in the call arguments.
1 parent 38a89ed commit 69d61e5

File tree

1 file changed

+25
-16
lines changed

1 file changed

+25
-16
lines changed

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -431,26 +431,35 @@ def __call__(
431431
elif prompt is not None and isinstance(prompt, list):
432432
batch_size = len(prompt)
433433
else:
434-
batch_size = prompt_embeds.shape[0]
434+
batch_size = len(prompt_embeds)
435435

436436
lora_scale = (
437437
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
438438
)
439-
(
440-
prompt_embeds,
441-
negative_prompt_embeds,
442-
) = self.encode_prompt(
443-
prompt=prompt,
444-
negative_prompt=negative_prompt,
445-
do_classifier_free_guidance=self.do_classifier_free_guidance,
446-
prompt_embeds=prompt_embeds,
447-
negative_prompt_embeds=negative_prompt_embeds,
448-
dtype=dtype,
449-
device=device,
450-
num_images_per_prompt=num_images_per_prompt,
451-
max_sequence_length=max_sequence_length,
452-
lora_scale=lora_scale,
453-
)
439+
440+
# If prompt_embeds is provided and prompt is None, skip encoding
441+
if prompt_embeds is not None and prompt is None:
442+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
443+
raise ValueError(
444+
"When `prompt_embeds` is provided without `prompt`, "
445+
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
446+
)
447+
else:
448+
(
449+
prompt_embeds,
450+
negative_prompt_embeds,
451+
) = self.encode_prompt(
452+
prompt=prompt,
453+
negative_prompt=negative_prompt,
454+
do_classifier_free_guidance=self.do_classifier_free_guidance,
455+
prompt_embeds=prompt_embeds,
456+
negative_prompt_embeds=negative_prompt_embeds,
457+
dtype=dtype,
458+
device=device,
459+
num_images_per_prompt=num_images_per_prompt,
460+
max_sequence_length=max_sequence_length,
461+
lora_scale=lora_scale,
462+
)
454463

455464
# 4. Prepare latent variables
456465
num_channels_latents = self.transformer.in_channels

0 commit comments

Comments (0)