Skip to content

Commit 549ad57

Browse files
committed
Merge branch 'z-image-dev-ql' into z-image-dev
2 parents 8e391b7 + 69d61e5 commit 549ad57

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,9 @@ def forward(
638638

639639
if torch.is_grad_enabled() and self.gradient_checkpointing:
640640
for layer in self.layers:
641-
unified = self._gradient_checkpointing_func(layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input)
641+
unified = self._gradient_checkpointing_func(
642+
layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
643+
)
642644
else:
643645
for layer in self.layers:
644646
unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@
4545
>>> # pipe.transformer.set_attention_backend("flash")
4646
>>> # (2) Use flash attention 3
4747
>>> # pipe.transformer.set_attention_backend("_flash_3")
48-
49-
>>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
48+
>>> prompt = '一幅为名为"造相「Z-IMAGE-TURBO」"的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。'
5049
>>> image = pipe(
5150
... prompt,
5251
... height=1024,
@@ -432,26 +431,35 @@ def __call__(
432431
elif prompt is not None and isinstance(prompt, list):
433432
batch_size = len(prompt)
434433
else:
435-
batch_size = prompt_embeds.shape[0]
434+
batch_size = len(prompt_embeds)
436435

437436
lora_scale = (
438437
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
439438
)
440-
(
441-
prompt_embeds,
442-
negative_prompt_embeds,
443-
) = self.encode_prompt(
444-
prompt=prompt,
445-
negative_prompt=negative_prompt,
446-
do_classifier_free_guidance=self.do_classifier_free_guidance,
447-
prompt_embeds=prompt_embeds,
448-
negative_prompt_embeds=negative_prompt_embeds,
449-
dtype=dtype,
450-
device=device,
451-
num_images_per_prompt=num_images_per_prompt,
452-
max_sequence_length=max_sequence_length,
453-
lora_scale=lora_scale,
454-
)
439+
440+
# If prompt_embeds is provided and prompt is None, skip encoding
441+
if prompt_embeds is not None and prompt is None:
442+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
443+
raise ValueError(
444+
"When `prompt_embeds` is provided without `prompt`, "
445+
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
446+
)
447+
else:
448+
(
449+
prompt_embeds,
450+
negative_prompt_embeds,
451+
) = self.encode_prompt(
452+
prompt=prompt,
453+
negative_prompt=negative_prompt,
454+
do_classifier_free_guidance=self.do_classifier_free_guidance,
455+
prompt_embeds=prompt_embeds,
456+
negative_prompt_embeds=negative_prompt_embeds,
457+
dtype=dtype,
458+
device=device,
459+
num_images_per_prompt=num_images_per_prompt,
460+
max_sequence_length=max_sequence_length,
461+
lora_scale=lora_scale,
462+
)
455463

456464
# 4. Prepare latent variables
457465
num_channels_latents = self.transformer.in_channels

0 commit comments

Comments
 (0)