
Commit 3e74bb2

Reformatting with make style, black & isort.
1 parent 42658fa commit 3e74bb2

File tree

5 files changed: +128 -65 lines changed

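The changes below are mechanical restyling. As a rough illustration of the two rewrites that dominate the diff (a hypothetical before/after snippet, not code from this commit):

# Hypothetical example only; names are illustrative.
# isort groups imports: standard library first, then third-party, alphabetized within each group.
from typing import Optional

import torch


# Before black: a signature crammed onto one long line.
def make_block_v1(layer_id: int, dim: int, n_heads: int, norm_eps: float = 1e-6, modulation: bool = True, device: Optional[torch.device] = None):
    return layer_id, dim, n_heads, norm_eps, modulation, device


# After black: the line exceeds the configured length (or ends in a magic trailing comma),
# so each argument moves onto its own line and a trailing comma is added.
def make_block_v2(
    layer_id: int,
    dim: int,
    n_heads: int,
    norm_eps: float = 1e-6,
    modulation: bool = True,
    device: Optional[torch.device] = None,
):
    return layer_id, dim, n_heads, norm_eps, modulation, device

Most hunks below are black splitting long signatures and call sites onto one argument per line with a trailing comma; isort accounts for the import reordering at the top of transformer_z_image.py and pipelines/__init__.py.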

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 107 additions & 49 deletions
@@ -16,10 +16,11 @@
 import math
 from typing import List, Optional, Tuple

-from einops import rearrange
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
+

 try:
     from flash_attn import flash_attn_varlen_func
@@ -33,10 +34,10 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...utils.torch_utils import maybe_allow_in_graph
-from ...models.attention_processor import Attention
-from ...models.attention_dispatch import dispatch_attention_fn
+

 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
@@ -88,10 +89,10 @@ def forward(self, t):

 class ZSingleStreamAttnProcessor:
     """
-    Processor for Z-Image single stream attention that adapts the existing Attention class
-    to match the behavior of the original Z-ImageAttention module.
+    Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
+    original Z-ImageAttention module.
     """
-
+
     _attention_backend = None
     _parallel_config = None

@@ -107,24 +108,24 @@ def __call__(
     ) -> torch.Tensor:
         x_shard = hidden_states
         x_freqs_cis_shard = image_rotary_emb
-
+
         query = attn.to_q(x_shard)
         key = attn.to_k(x_shard)
         value = attn.to_v(x_shard)
-
+
         seqlen_shard = x_shard.shape[0]
-
+
         # Reshape to [seq_len, heads, head_dim]
         head_dim = query.shape[-1] // attn.heads
         query = query.view(seqlen_shard, attn.heads, head_dim)
         key = key.view(seqlen_shard, attn.heads, head_dim)
-        value = value.view(seqlen_shard, attn.heads, head_dim)
+        value = value.view(seqlen_shard, attn.heads, head_dim)
         # Apply Norms
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)
-
+
         # Apply RoPE
         def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
             with torch.amp.autocast("cuda", enabled=False):
@@ -136,17 +137,17 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
         if x_freqs_cis_shard is not None:
             query = apply_rotary_emb(query, x_freqs_cis_shard)
             key = apply_rotary_emb(key, x_freqs_cis_shard)
-
+
         # Cast to correct dtype
         dtype = query.dtype
         query, key = query.to(dtype), key.to(dtype)
-
+
         # Flash Attention
         softmax_scale = math.sqrt(1 / head_dim)
         assert dtype in [torch.float16, torch.bfloat16]
-
+
         if x_cu_seqlens is None or x_max_item_seqlen is None:
-            raise ValueError("x_cu_seqlens and x_max_item_seqlen are required for ZSingleStreamAttnProcessor")
+            raise ValueError("x_cu_seqlens and x_max_item_seqlen are required for ZSingleStreamAttnProcessor")

         if flash_attn_varlen_func is not None:
             output = flash_attn_varlen_func(
@@ -164,45 +165,50 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
             output = output.flatten(-2)
         else:
             seqlens = (x_cu_seqlens[1:] - x_cu_seqlens[:-1]).cpu().tolist()
-
+
             q_split = torch.split(query, seqlens, dim=0)
             k_split = torch.split(key, seqlens, dim=0)
             v_split = torch.split(value, seqlens, dim=0)
-
+
            q_padded = torch.nn.utils.rnn.pad_sequence(q_split, batch_first=True)
             k_padded = torch.nn.utils.rnn.pad_sequence(k_split, batch_first=True)
             v_padded = torch.nn.utils.rnn.pad_sequence(v_split, batch_first=True)
-
+
             batch_size, max_seqlen, _, _ = q_padded.shape
-
+
             mask = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=query.device)
             for i, l in enumerate(seqlens):
                 mask[i, :l] = True
-
+
             attn_mask = torch.zeros((batch_size, 1, 1, max_seqlen), dtype=query.dtype, device=query.device)
             attn_mask.masked_fill_(~mask[:, None, None, :], torch.finfo(query.dtype).min)
-
+
             q_padded = q_padded.transpose(1, 2)
             k_padded = k_padded.transpose(1, 2)
             v_padded = v_padded.transpose(1, 2)
-
+
             output = F.scaled_dot_product_attention(
-                q_padded, k_padded, v_padded, attn_mask=attn_mask, dropout_p=0.0, scale=softmax_scale
+                q_padded,
+                k_padded,
+                v_padded,
+                attn_mask=attn_mask,
+                dropout_p=0.0,
+                scale=softmax_scale,
             )
-
+
             output = output.transpose(1, 2)
-
+
             out_list = []
             for i, l in enumerate(seqlens):
                 out_list.append(output[i, :l])
-
+
             output = torch.cat(out_list, dim=0)
             output = output.flatten(-2)

         output = attn.to_out[0](output)
-        if len(attn.to_out) > 1: # dropout
-            output = attn.to_out[1](output)
-
+        if len(attn.to_out) > 1:  # dropout
+            output = attn.to_out[1](output)
+
         return output


@@ -226,12 +232,19 @@ def forward(self, x):
 @maybe_allow_in_graph
 class ZImageTransformerBlock(nn.Module):
     def __init__(
-        self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True
+        self,
+        layer_id: int,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        norm_eps: float,
+        qk_norm: bool,
+        modulation=True,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-
+
         # Refactored to use diffusers Attention with custom processor
         # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
         self.attention = Attention(
@@ -244,7 +257,7 @@ def __init__(
             bias=False,
             processor=ZSingleStreamAttnProcessor(),
         )
-
+
         self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
         self.layer_id = layer_id

@@ -284,7 +297,12 @@ def forward(
             x_src_ids_shard = None

         x_shard = self.attn_forward(
-            x_shard, x_freqs_cis_shard, x_cu_seqlens, x_max_item_seqlen, scale_gate_msa, x_src_ids_shard
+            x_shard,
+            x_freqs_cis_shard,
+            x_cu_seqlens,
+            x_max_item_seqlen,
+            scale_gate_msa,
+            x_src_ids_shard,
         )

         x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
@@ -303,22 +321,22 @@ def attn_forward(
         if self.modulation:
             assert scale_gate is not None and x_src_ids_shard is not None
             scale_msa, gate_msa = scale_gate
-
+
             # Pass extra args needed for ZSingleStreamAttnProcessor
             attn_out = self.attention(
                 self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
                 image_rotary_emb=x_freqs_cis_shard,
                 x_cu_seqlens=x_cu_seqlens,
-                x_max_item_seqlen=x_max_item_seqlen
+                x_max_item_seqlen=x_max_item_seqlen,
             )
-
+
             x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
         else:
             attn_out = self.attention(
                 self.attention_norm1(x_shard),
                 image_rotary_emb=x_freqs_cis_shard,
                 x_cu_seqlens=x_cu_seqlens,
-                x_max_item_seqlen=x_max_item_seqlen
+                x_max_item_seqlen=x_max_item_seqlen,
             )
             x_shard = x_shard + self.attention_norm2(attn_out)
         return x_shard
@@ -371,7 +389,10 @@ def forward(self, x_shard, x_src_ids_shard, c):

 class RopeEmbedder:
     def __init__(
-        self, theta: float = 256.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (64, 128, 128)
+        self,
+        theta: float = 256.0,
+        axes_dims: List[int] = (16, 56, 56),
+        axes_lens: List[int] = (64, 128, 128),
     ):
         self.theta = theta
         self.axes_dims = axes_dims
@@ -458,13 +479,29 @@ def __init__(
         self.all_final_layer = nn.ModuleDict(all_final_layer)
         self.noise_refiner = nn.ModuleList(
             [
-                ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
+                ZImageTransformerBlock(
+                    1000 + layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    norm_eps,
+                    qk_norm,
+                    modulation=True,
+                )
                 for layer_id in range(n_refiner_layers)
             ]
         )
         self.context_refiner = nn.ModuleList(
             [
-                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
+                ZImageTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    norm_eps,
+                    qk_norm,
+                    modulation=False,
+                )
                 for layer_id in range(n_refiner_layers)
             ]
         )
@@ -524,8 +561,6 @@ def patchify_and_embed(
         patch_size: int,
         f_patch_size: int,
     ):
-
-        bsz = len(all_image)
         pH = pW = patch_size
         pF = f_patch_size
         device = all_image[0].device
@@ -560,7 +595,10 @@
                 )
             )
             # padded feature
-            cap_padded_feat = torch.cat([all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)], dim=0)
+            cap_padded_feat = torch.cat(
+                [all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)],
+                dim=0,
+            )
             all_cap_feats_out.append(cap_padded_feat)

             ### Process Image
@@ -623,7 +661,6 @@ def forward(
         patch_size=2,
         f_patch_size=1,
     ):
-
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size

@@ -649,7 +686,11 @@ def forward(
         assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
         x_max_item_seqlen = max(x_item_seqlens)
         x_cu_seqlens = F.pad(
-            torch.cumsum(torch.tensor(x_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32),
+            torch.cumsum(
+                torch.tensor(x_item_seqlens, dtype=torch.int32, device=device),
+                dim=0,
+                dtype=torch.int32,
+            ),
             (1, 0),
         )
         x_src_ids = [
@@ -666,15 +707,26 @@ def forward(
         x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
         x_shard[x_pad_mask_shard] = self.x_pad_token
         for layer in self.noise_refiner:
-            x_shard = layer(x_shard, x_src_ids_shard, x_freqs_cis_shard, x_cu_seqlens, x_max_item_seqlen, adaln_input)
+            x_shard = layer(
+                x_shard,
+                x_src_ids_shard,
+                x_freqs_cis_shard,
+                x_cu_seqlens,
+                x_max_item_seqlen,
+                adaln_input,
+            )
         x_flatten = x_shard

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
         assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)
         cap_cu_seqlens = F.pad(
-            torch.cumsum(torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32),
+            torch.cumsum(
+                torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device),
+                dim=0,
+                dtype=torch.int32,
+            ),
             (1, 0),
         )
         cap_src_ids = [
@@ -705,14 +757,20 @@ def merge_interleave(l1, l2):
             return list(itertools.chain(*zip(l1, l2)))

         unified = torch.cat(
-            merge_interleave(cap_flatten.split(cap_item_seqlens, dim=0), x_flatten.split(x_item_seqlens, dim=0)), dim=0
+            merge_interleave(
+                cap_flatten.split(cap_item_seqlens, dim=0),
+                x_flatten.split(x_item_seqlens, dim=0),
+            ),
+            dim=0,
         )
         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
         assert len(unified) == sum(unified_item_seqlens)
         unified_max_item_seqlen = max(unified_item_seqlens)
         unified_cu_seqlens = F.pad(
             torch.cumsum(
-                torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32
+                torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device),
+                dim=0,
+                dtype=torch.int32,
             ),
             (1, 0),
         )
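For context on the largest hunk above: when flash-attn is not installed, ZSingleStreamAttnProcessor falls back to padding the packed variable-length sequences and calling torch.nn.functional.scaled_dot_product_attention with an additive padding mask. Below is a minimal standalone sketch of that pattern; the function name and argument layout are illustrative, not the diffusers API.

import torch
import torch.nn.functional as F


def varlen_sdpa_fallback(query, key, value, cu_seqlens, softmax_scale):
    """Sketch of a padded-SDPA fallback for packed variable-length attention.

    query/key/value: [total_tokens, heads, head_dim], with several items packed along dim 0.
    cu_seqlens: int32 cumulative sequence lengths with a leading zero, e.g. tensor([0, 96, 224]).
    """
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().tolist()

    # Unpack the flat token stream into per-item tensors, then right-pad to a common length.
    q_pad = torch.nn.utils.rnn.pad_sequence(torch.split(query, seqlens, dim=0), batch_first=True)
    k_pad = torch.nn.utils.rnn.pad_sequence(torch.split(key, seqlens, dim=0), batch_first=True)
    v_pad = torch.nn.utils.rnn.pad_sequence(torch.split(value, seqlens, dim=0), batch_first=True)
    batch_size, max_seqlen = q_pad.shape[0], q_pad.shape[1]

    # Additive mask: 0 for real key positions, dtype-min for padding.
    keep = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=query.device)
    for i, n in enumerate(seqlens):
        keep[i, :n] = True
    attn_mask = torch.zeros((batch_size, 1, 1, max_seqlen), dtype=query.dtype, device=query.device)
    attn_mask.masked_fill_(~keep[:, None, None, :], torch.finfo(query.dtype).min)

    # SDPA expects [batch, heads, seq, head_dim].
    out = F.scaled_dot_product_attention(
        q_pad.transpose(1, 2),
        k_pad.transpose(1, 2),
        v_pad.transpose(1, 2),
        attn_mask=attn_mask,
        dropout_p=0.0,
        scale=softmax_scale,
    )
    out = out.transpose(1, 2)

    # Drop the padding and re-pack to [total_tokens, heads * head_dim].
    return torch.cat([out[i, :n] for i, n in enumerate(seqlens)], dim=0).flatten(-2)


if __name__ == "__main__":
    torch.manual_seed(0)
    q = torch.randn(224, 4, 16)  # two packed items of 96 and 128 tokens, 4 heads, head_dim 16
    k = torch.randn(224, 4, 16)
    v = torch.randn(224, 4, 16)
    cu = torch.tensor([0, 96, 224], dtype=torch.int32)
    print(varlen_sdpa_fallback(q, k, v, cu, softmax_scale=16**-0.5).shape)  # torch.Size([224, 64])

Note that softmax_scale = math.sqrt(1 / head_dim) in the processor equals the 1/sqrt(head_dim) default used by both flash-attn and SDPA, so passing it explicitly keeps the two code paths numerically aligned.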

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -820,12 +820,12 @@
         WanVACEPipeline,
         WanVideoToVideoPipeline,
     )
-    from .z_image import ZImagePipeline
     from .wuerstchen import (
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
     )
+    from .z_image import ZImagePipeline

 try:
     if not is_onnx_available():

src/diffusers/pipelines/z_image/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -48,4 +48,3 @@

     for name, value in _dummy_objects.items():
         setattr(sys.modules[__name__], name, value)
-
