Skip to content

Commit 2f2d8c3

Browse files
committed
Merge branch 'z-image-dev' into z-image
2 parents 336c5ce + 1dd8f3c commit 2f2d8c3

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919
import torch.nn as nn
2020
import torch.nn.functional as F
21-
from einops import rearrange
2221
from torch.nn.utils.rnn import pad_sequence
2322

2423
from ...configuration_utils import ConfigMixin, register_to_config
@@ -429,9 +428,12 @@ def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_pat
429428
for i in range(bsz):
430429
F, H, W = size[i]
431430
ori_len = (F // pF) * (H // pH) * (W // pW)
432-
x[i] = rearrange(
433-
x[i][:ori_len].view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels),
434-
"f h w pf ph pw c -> c (f pf) (h ph) (w pw)",
431+
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
432+
x[i] = (
433+
x[i][:ori_len]
434+
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
435+
.permute(6, 0, 3, 1, 4, 2, 5)
436+
.reshape(self.out_channels, F, H, W)
435437
)
436438
return x
437439

@@ -497,7 +499,8 @@ def patchify_and_embed(
497499
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
498500

499501
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
500-
image = rearrange(image, "c f pf h ph w pw -> (f h w) (pf ph pw c)")
502+
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
503+
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
501504

502505
image_ori_len = len(image)
503506
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -432,26 +432,35 @@ def __call__(
432432
elif prompt is not None and isinstance(prompt, list):
433433
batch_size = len(prompt)
434434
else:
435-
batch_size = prompt_embeds.shape[0]
435+
batch_size = len(prompt_embeds)
436436

437437
lora_scale = (
438438
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
439439
)
440-
(
441-
prompt_embeds,
442-
negative_prompt_embeds,
443-
) = self.encode_prompt(
444-
prompt=prompt,
445-
negative_prompt=negative_prompt,
446-
do_classifier_free_guidance=self.do_classifier_free_guidance,
447-
prompt_embeds=prompt_embeds,
448-
negative_prompt_embeds=negative_prompt_embeds,
449-
dtype=dtype,
450-
device=device,
451-
num_images_per_prompt=num_images_per_prompt,
452-
max_sequence_length=max_sequence_length,
453-
lora_scale=lora_scale,
454-
)
440+
441+
# If prompt_embeds is provided and prompt is None, skip encoding
442+
if prompt_embeds is not None and prompt is None:
443+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
444+
raise ValueError(
445+
"When `prompt_embeds` is provided without `prompt`, "
446+
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
447+
)
448+
else:
449+
(
450+
prompt_embeds,
451+
negative_prompt_embeds,
452+
) = self.encode_prompt(
453+
prompt=prompt,
454+
negative_prompt=negative_prompt,
455+
do_classifier_free_guidance=self.do_classifier_free_guidance,
456+
prompt_embeds=prompt_embeds,
457+
negative_prompt_embeds=negative_prompt_embeds,
458+
dtype=dtype,
459+
device=device,
460+
num_images_per_prompt=num_images_per_prompt,
461+
max_sequence_length=max_sequence_length,
462+
lora_scale=lora_scale,
463+
)
455464

456465
# 4. Prepare latent variables
457466
num_channels_latents = self.transformer.in_channels

0 commit comments

Comments
 (0)