
Commit 1dd587b

Merge remote-tracking branch 'JerryWu-code/z-image-dev' into fork/JerryWu-code/z-image
# Conflicts:
#   src/diffusers/models/transformers/transformer_z_image.py

2 parents: 7df350d + a4b89a0 · commit 1dd587b

File tree: 3 files changed (+54, -132 lines)

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 49 additions & 103 deletions
@@ -21,24 +21,25 @@
 import torch.nn.functional as F
 from einops import rearrange
 
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.modeling_utils import ModelMixin
+from ...utils.import_utils import is_apex_available, is_flash_attn_available
+from ...utils.torch_utils import maybe_allow_in_graph
 
-try:
+
+if is_flash_attn_available():
     from flash_attn import flash_attn_varlen_func
-except ImportError:
+else:
     flash_attn_varlen_func = None
 
-# todo see how other teams do this
-try:
+if is_apex_available():
+    # Here needs apex with "APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation ."
     from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
+else:
     from torch.nn import RMSNorm
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...utils.torch_utils import maybe_allow_in_graph
-
 
 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
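The hunk above swaps bare `try/except ImportError` blocks for explicit `is_flash_attn_available()` / `is_apex_available()` guards from `diffusers.utils.import_utils`. A minimal sketch of that guard pattern, using `importlib.util.find_spec` as a stand-in for the real helper (the name `_package_available` below is hypothetical):

```python
import importlib.util


def _package_available(pkg_name: str) -> bool:
    # find_spec returns None when the package cannot be imported at all
    return importlib.util.find_spec(pkg_name) is not None


if _package_available("flash_attn"):
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None  # callers fall back to a non-varlen attention path

if _package_available("apex"):
    from apex.normalization import FusedRMSNorm as RMSNorm
else:
    from torch.nn import RMSNorm  # plain PyTorch RMSNorm as the portable default
```

Centralizing the check also lets the same flag be reused elsewhere, which is what the `import_utils.py` hunks at the bottom of this diff add.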
@@ -103,20 +104,20 @@ def __call__(
         x_cu_seqlens: Optional[torch.Tensor] = None,
         x_max_item_seqlen: Optional[int] = None,
     ) -> torch.Tensor:
-        x_shard = hidden_states
-        x_freqs_cis_shard = image_rotary_emb
+        x = hidden_states
+        x_freqs_cis = image_rotary_emb
 
-        query = attn.to_q(x_shard)
-        key = attn.to_k(x_shard)
-        value = attn.to_v(x_shard)
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
 
-        seqlen_shard = x_shard.shape[0]
+        seqlen = x.shape[0]
 
         # Reshape to [seq_len, heads, head_dim]
         head_dim = query.shape[-1] // attn.heads
-        query = query.view(seqlen_shard, attn.heads, head_dim)
-        key = key.view(seqlen_shard, attn.heads, head_dim)
-        value = value.view(seqlen_shard, attn.heads, head_dim)
+        query = query.view(seqlen, attn.heads, head_dim)
+        key = key.view(seqlen, attn.heads, head_dim)
+        value = value.view(seqlen, attn.heads, head_dim)
         # Apply Norms
         if attn.norm_q is not None:
             query = attn.norm_q(query)
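For context on the shapes involved: the processor works on packed sequences, so tensors carry no batch dimension and items are delimited by `x_cu_seqlens`, matching what `flash_attn_varlen_func` expects. A shape-only sketch with toy sizes, not the model's real dimensions:

```python
import torch

heads, head_dim = 4, 8
seqlens = torch.tensor([3, 5, 2])               # tokens per item in the packed batch
total = int(seqlens.sum())                      # 10 tokens packed along dim 0

hidden_states = torch.randn(total, heads * head_dim)
query = hidden_states.view(total, heads, head_dim)    # [total_tokens, heads, head_dim]

x_cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
x_cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)       # [0, 3, 8, 10]
x_max_item_seqlen = int(seqlens.max())                # 5

print(query.shape, x_cu_seqlens.tolist(), x_max_item_seqlen)
```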
@@ -131,9 +132,9 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
             x_out = torch.view_as_real(x * freqs_cis).flatten(2)
             return x_out.type_as(x_in)
 
-        if x_freqs_cis_shard is not None:
-            query = apply_rotary_emb(query, x_freqs_cis_shard)
-            key = apply_rotary_emb(key, x_freqs_cis_shard)
+        if x_freqs_cis is not None:
+            query = apply_rotary_emb(query, x_freqs_cis)
+            key = apply_rotary_emb(key, x_freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
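The `apply_rotary_emb` helper touched above treats pairs of channels as complex numbers and rotates them by precomputed unit phasors (`freqs_cis`). A self-contained sketch of that pattern; the reshape-to-complex line is a reconstruction of the part not visible in this hunk, and the shapes are illustrative:

```python
import torch


def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # [seq, heads, head_dim] -> complex [seq, heads, head_dim // 2]
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x * freqs_cis).flatten(2)   # rotate, then back to real pairs
    return x_out.type_as(x_in)


seq, heads, head_dim = 10, 4, 8
query = torch.randn(seq, heads, head_dim)
angles = torch.rand(seq, 1, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)   # unit-magnitude complex rotations
print(apply_rotary_emb(query, freqs_cis).shape)            # torch.Size([10, 4, 8])
```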
@@ -274,9 +275,9 @@ def __init__(
 
     def forward(
         self,
-        x_shard: torch.Tensor,
-        x_src_ids_shard: torch.Tensor,
-        x_freqs_cis_shard: torch.Tensor,
+        x: torch.Tensor,
+        x_src_ids: torch.Tensor,
+        x_freqs_cis: torch.Tensor,
         x_cu_seqlens: torch.Tensor,
         x_max_item_seqlen: int,
         adaln_input: Optional[torch.Tensor] = None,

@@ -286,80 +287,40 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
             gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
             scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
-            scale_gate_msa = (scale_msa, gate_msa)
-            scale_gate_mlp = (scale_mlp, gate_mlp)
-        else:
-            scale_gate_msa = None
-            scale_gate_mlp = None
-            x_src_ids_shard = None
-
-        x_shard = self.attn_forward(
-            x_shard,
-            x_freqs_cis_shard,
-            x_cu_seqlens,
-            x_max_item_seqlen,
-            scale_gate_msa,
-            x_src_ids_shard,
-        )
 
-        x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
-
-        return x_shard
-
-    def attn_forward(
-        self,
-        x_shard,
-        x_freqs_cis_shard,
-        x_cu_seqlens,
-        x_max_item_seqlen,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_msa, gate_msa = scale_gate
-
-            # Pass extra args needed for ZSingleStreamAttnProcessor
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x) * scale_msa[x_src_ids],
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
+            x = x + gate_msa[x_src_ids] * self.attention_norm2(attn_out)
 
-            x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
+            # FFN block
+            x = x + gate_mlp[x_src_ids] * self.ffn_norm2(
+                self.feed_forward(
+                    self.ffn_norm1(x) * scale_mlp[x_src_ids],
+                )
+            )
         else:
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard),
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x),
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
            )
-            x_shard = x_shard + self.attention_norm2(attn_out)
-        return x_shard
+            x = x + self.attention_norm2(attn_out)
 
-    def ffn_forward(
-        self,
-        x_shard,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_mlp, gate_mlp = scale_gate
-            x_shard = x_shard + gate_mlp[x_src_ids_shard] * self.ffn_norm2(
+            # FFN block
+            x = x + self.ffn_norm2(
                 self.feed_forward(
-                    self.ffn_norm1(x_shard) * scale_mlp[x_src_ids_shard],
+                    self.ffn_norm1(x),
                 )
             )
 
-        else:
-            x_shard = x_shard + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x_shard),
-                )
-            )
-        return x_shard
+        return x
 
 
 class FinalLayer(nn.Module):
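The net effect of this hunk is to fold the former `attn_forward`/`ffn_forward` helpers back into `forward` while keeping the adaLN-style modulation: one projection yields `(scale_msa, gate_msa, scale_mlp, gate_mlp)`, gates are squashed with `tanh`, scales are shifted by 1, and both are indexed per token via `x_src_ids`. A toy sketch of the attention half of that pattern; the module sizes and names below are illustrative, not the real block:

```python
import torch
import torch.nn as nn

dim, adaln_dim, n_tokens, n_items = 32, 16, 10, 2
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_dim, 4 * dim))
attention_norm1 = nn.LayerNorm(dim)

x = torch.randn(n_tokens, dim)                        # packed tokens
x_src_ids = torch.randint(0, n_items, (n_tokens,))    # which item each token belongs to
adaln_input = torch.randn(n_items, adaln_dim)         # one conditioning vector per item

scale_msa, gate_msa, scale_mlp, gate_mlp = adaLN_modulation(adaln_input).chunk(4, dim=1)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

h = attention_norm1(x) * scale_msa[x_src_ids]         # pre-norm input, scaled per item
# (the real block runs self.attention(...) on `h` here; skipped to keep the sketch short)
x = x + gate_msa[x_src_ids] * h                       # gated residual update
print(x.shape)                                        # torch.Size([10, 32])
```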
@@ -377,11 +338,11 @@ def __init__(self, hidden_size, out_channels):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
-    def forward(self, x_shard, x_src_ids_shard, c):
+    def forward(self, x, x_src_ids, c):
         scale = 1.0 + self.adaLN_modulation(c)
-        x_shard = self.norm_final(x_shard) * scale[x_src_ids_shard]
-        x_shard = self.linear(x_shard)
-        return x_shard
+        x = self.norm_final(x) * scale[x_src_ids]
+        x = self.linear(x)
+        return x
 
 
 class RopeEmbedder:

@@ -465,8 +426,6 @@ def __init__(
         all_final_layer = {}
         for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            nn.init.xavier_uniform_(x_embedder.weight)
-            nn.init.constant_(x_embedder.bias, 0.0)
             all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
 
             final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
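Dropping the explicit initialization means `x_embedder` now keeps `nn.Linear`'s default (Kaiming-uniform) init. If Xavier initialization is still wanted for a particular run, it can be reapplied after construction; the sizes below are placeholders:

```python
import torch.nn as nn

# f_patch_size * patch_size * patch_size * in_channels -> dim, with illustrative numbers
x_embedder = nn.Linear(1 * 2 * 2 * 16, 64, bias=True)
nn.init.xavier_uniform_(x_embedder.weight)
nn.init.constant_(x_embedder.bias, 0.0)
```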
@@ -793,16 +752,3 @@ def forward(
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
 
         return x, {}
-
-    def parameter_count(self) -> int:
-        total_params = 0
-
-        def _recursive_count_params(module):
-            nonlocal total_params
-            for param in module.parameters(recurse=False):
-                total_params += param.numel()
-            for submodule in module.children():
-                _recursive_count_params(submodule)
-
-        _recursive_count_params(self)
-        return total_params
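The removed `parameter_count()` helper duplicated what `Module.parameters()` already provides, since `parameters()` recurses through submodules by default. The usual one-liner, shown on a stand-in module:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
total_params = sum(p.numel() for p in model.parameters())
print(total_params)  # (8*16 + 16) + (16*4 + 4) = 212
```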

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 0 additions & 29 deletions
@@ -249,35 +249,6 @@ def _encode_prompt(
 
         return embeddings_list
 
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_tiling()
-
     def prepare_latents(
         self,
         batch_size,
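These pipeline-level wrappers simply forwarded to the VAE, so the same memory savings remain available by calling the VAE's own methods directly. A minimal sketch on a bare `AutoencoderKL`; a real pipeline would use its loaded `pipe.vae` instead of this untrained stand-in:

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL()        # untrained stand-in; in practice this is `pipe.vae`
vae.enable_slicing()         # decode the batch one sample at a time
vae.enable_tiling()          # decode large latents tile by tile

vae.disable_slicing()        # back to single-step decoding
vae.disable_tiling()
```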

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
@@ -230,6 +230,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
+_apex_available, _apex_version = _is_package_available("apex")
 
 
 def is_torch_available():
@@ -420,6 +421,10 @@ def is_kornia_available():
     return _kornia_available
 
 
+def is_apex_available():
+    return _apex_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
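For reference, a rough sketch of what an `_is_package_available`-style helper typically does (an importability check plus an installed version lookup), which is what the new `_apex_available` flag and `is_apex_available()` accessor build on; the real helper in `diffusers.utils.import_utils` carries extra bookkeeping such as the `get_dist_name` path:

```python
import importlib.metadata
import importlib.util
from typing import Tuple


def _is_package_available(pkg_name: str) -> Tuple[bool, str]:
    # Importable at all?
    if importlib.util.find_spec(pkg_name) is None:
        return False, "N/A"
    try:
        return True, importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        return True, "N/A"  # importable but no distribution metadata


_apex_available, _apex_version = _is_package_available("apex")


def is_apex_available():
    return _apex_available
```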
