
Commit a4b89a0

Remove init, Modify import utils, Merge forward in transformers block, Remove once func in pipeline.
1 parent 3e74bb2 commit a4b89a0

File tree

3 files changed: +94 additions, -172 deletions

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 89 additions & 143 deletions
@@ -21,23 +21,25 @@
 import torch.nn.functional as F
 from einops import rearrange
 
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.modeling_utils import ModelMixin
+from ...utils.import_utils import is_apex_available, is_flash_attn_available
+from ...utils.torch_utils import maybe_allow_in_graph
+
 
-try:
+if is_flash_attn_available():
     from flash_attn import flash_attn_varlen_func
-except ImportError:
+else:
     flash_attn_varlen_func = None
 
-try:
+if is_apex_available():
+    # Here needs apex with "APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation ."
     from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
+else:
     from torch.nn import RMSNorm
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...utils.torch_utils import maybe_allow_in_graph
-
 
 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
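Note: the `try/except ImportError` guards are replaced by availability helpers imported from `...utils.import_utils`. A minimal sketch of the general idiom behind such helpers, assuming an `importlib`-based check (not the actual diffusers implementation):

```python
# Sketch of the availability-check idiom the new imports rely on. This is an
# assumption about the general pattern, not the diffusers code itself.
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be found
    return importlib.util.find_spec(name) is not None


def is_flash_attn_available() -> bool:
    return _is_package_available("flash_attn")


def is_apex_available() -> bool:
    return _is_package_available("apex")
```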
@@ -61,10 +63,6 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
                 bias=True,
             ),
         )
-        nn.init.normal_(self.mlp[0].weight, std=0.02)
-        nn.init.zeros_(self.mlp[0].bias)
-        nn.init.normal_(self.mlp[2].weight, std=0.02)
-        nn.init.zeros_(self.mlp[2].bias)
 
         self.frequency_embedding_size = frequency_embedding_size
 
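With the explicit `nn.init` calls removed, the timestep MLP falls back to PyTorch's default `nn.Linear` initialization, and loaded checkpoints overwrite the weights anyway. If parity with the old random init is ever needed, it can be reapplied from outside the module; a hypothetical helper (not part of this commit):

```python
import torch.nn as nn


def reinit_timestep_mlp(mlp: nn.Sequential) -> None:
    # Mirrors the deleted lines: normal(std=0.02) weights and zero biases
    # on the two Linear layers of the Sequential (indices 0 and 2).
    nn.init.normal_(mlp[0].weight, std=0.02)
    nn.init.zeros_(mlp[0].bias)
    nn.init.normal_(mlp[2].weight, std=0.02)
    nn.init.zeros_(mlp[2].bias)
```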
@@ -106,20 +104,20 @@ def __call__(
         x_cu_seqlens: Optional[torch.Tensor] = None,
         x_max_item_seqlen: Optional[int] = None,
     ) -> torch.Tensor:
-        x_shard = hidden_states
-        x_freqs_cis_shard = image_rotary_emb
+        x = hidden_states
+        x_freqs_cis = image_rotary_emb
 
-        query = attn.to_q(x_shard)
-        key = attn.to_k(x_shard)
-        value = attn.to_v(x_shard)
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
 
-        seqlen_shard = x_shard.shape[0]
+        seqlen = x.shape[0]
 
         # Reshape to [seq_len, heads, head_dim]
         head_dim = query.shape[-1] // attn.heads
-        query = query.view(seqlen_shard, attn.heads, head_dim)
-        key = key.view(seqlen_shard, attn.heads, head_dim)
-        value = value.view(seqlen_shard, attn.heads, head_dim)
+        query = query.view(seqlen, attn.heads, head_dim)
+        key = key.view(seqlen, attn.heads, head_dim)
+        value = value.view(seqlen, attn.heads, head_dim)
         # Apply Norms
         if attn.norm_q is not None:
             query = attn.norm_q(query)
@@ -134,9 +132,9 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
             x_out = torch.view_as_real(x * freqs_cis).flatten(2)
             return x_out.type_as(x_in)
 
-        if x_freqs_cis_shard is not None:
-            query = apply_rotary_emb(query, x_freqs_cis_shard)
-            key = apply_rotary_emb(key, x_freqs_cis_shard)
+        if x_freqs_cis is not None:
+            query = apply_rotary_emb(query, x_freqs_cis)
+            key = apply_rotary_emb(key, x_freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
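For reference, the packed `[total_tokens, heads, head_dim]` tensors plus `x_cu_seqlens` / `x_max_item_seqlen` are exactly what variable-length flash attention consumes. A hedged sketch of how they are typically used, with an SDPA fallback; the exact call inside this processor may differ:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None


def varlen_attention(query, key, value, cu_seqlens, max_seqlen):
    # query/key/value: [total_tokens, heads, head_dim], packed over all batch items
    if flash_attn_varlen_func is not None:
        return flash_attn_varlen_func(
            query, key, value,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        )
    # Rough fallback: treat the packed tokens as one sequence. This ignores item
    # boundaries, so it is only a stand-in, not equivalent to the varlen kernel.
    q, k, v = (t.permute(1, 0, 2).unsqueeze(0) for t in (query, key, value))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.squeeze(0).permute(1, 0, 2)
```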
@@ -277,9 +275,9 @@ def __init__(
 
     def forward(
         self,
-        x_shard: torch.Tensor,
-        x_src_ids_shard: torch.Tensor,
-        x_freqs_cis_shard: torch.Tensor,
+        x: torch.Tensor,
+        x_src_ids: torch.Tensor,
+        x_freqs_cis: torch.Tensor,
         x_cu_seqlens: torch.Tensor,
         x_max_item_seqlen: int,
         adaln_input: Optional[torch.Tensor] = None,
@@ -289,80 +287,40 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
             gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
             scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
-            scale_gate_msa = (scale_msa, gate_msa)
-            scale_gate_mlp = (scale_mlp, gate_mlp)
-        else:
-            scale_gate_msa = None
-            scale_gate_mlp = None
-            x_src_ids_shard = None
-
-        x_shard = self.attn_forward(
-            x_shard,
-            x_freqs_cis_shard,
-            x_cu_seqlens,
-            x_max_item_seqlen,
-            scale_gate_msa,
-            x_src_ids_shard,
-        )
 
-        x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
-
-        return x_shard
-
-    def attn_forward(
-        self,
-        x_shard,
-        x_freqs_cis_shard,
-        x_cu_seqlens,
-        x_max_item_seqlen,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_msa, gate_msa = scale_gate
-
-            # Pass extra args needed for ZSingleStreamAttnProcessor
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x) * scale_msa[x_src_ids],
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
+            x = x + gate_msa[x_src_ids] * self.attention_norm2(attn_out)
 
-            x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
+            # FFN block
+            x = x + gate_mlp[x_src_ids] * self.ffn_norm2(
+                self.feed_forward(
+                    self.ffn_norm1(x) * scale_mlp[x_src_ids],
+                )
+            )
         else:
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard),
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x),
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
-            x_shard = x_shard + self.attention_norm2(attn_out)
-        return x_shard
+            x = x + self.attention_norm2(attn_out)
 
-    def ffn_forward(
-        self,
-        x_shard,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_mlp, gate_mlp = scale_gate
-            x_shard = x_shard + gate_mlp[x_src_ids_shard] * self.ffn_norm2(
+            # FFN block
+            x = x + self.ffn_norm2(
                 self.feed_forward(
-                    self.ffn_norm1(x_shard) * scale_mlp[x_src_ids_shard],
+                    self.ffn_norm1(x),
                 )
             )
 
-        else:
-            x_shard = x_shard + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x_shard),
-                )
-            )
-        return x_shard
+        return x
 
 
 class FinalLayer(nn.Module):
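The merged forward keeps the AdaLN pattern: one `(scale, gate)` pair per batch item, broadcast onto that item's packed tokens via `x_src_ids` indexing. A self-contained toy sketch of the same pattern (illustrative names and dimensions, not the actual block):

```python
import torch
import torch.nn as nn


class ToyAdaLNBlock(nn.Module):
    def __init__(self, dim: int, adaln_dim: int = 256):
        super().__init__()
        self.adaLN_modulation = nn.Linear(adaln_dim, 4 * dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn_like = nn.Linear(dim, dim)  # stands in for the attention sub-layer
        self.mlp = nn.Linear(dim, dim)        # stands in for the feed-forward sub-layer

    def forward(self, x, x_src_ids, adaln_input):
        # x: [total_tokens, dim] packed over the batch; adaln_input: [batch, adaln_dim]
        scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
        gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
        scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
        # Indexing with x_src_ids broadcasts each item's modulation over its tokens
        x = x + gate_msa[x_src_ids] * self.attn_like(self.norm1(x) * scale_msa[x_src_ids])
        x = x + gate_mlp[x_src_ids] * self.mlp(self.norm2(x) * scale_mlp[x_src_ids])
        return x


block = ToyAdaLNBlock(dim=8)
x = torch.randn(5, 8)                      # 5 packed tokens
x_src_ids = torch.tensor([0, 0, 0, 1, 1])  # tokens 0-2 belong to item 0, tokens 3-4 to item 1
out = block(x, x_src_ids, torch.randn(2, 256))
```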
@@ -380,11 +338,11 @@ def __init__(self, hidden_size, out_channels):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
-    def forward(self, x_shard, x_src_ids_shard, c):
+    def forward(self, x, x_src_ids, c):
         scale = 1.0 + self.adaLN_modulation(c)
-        x_shard = self.norm_final(x_shard) * scale[x_src_ids_shard]
-        x_shard = self.linear(x_shard)
-        return x_shard
+        x = self.norm_final(x) * scale[x_src_ids]
+        x = self.linear(x)
+        return x
 
 
 class RopeEmbedder:
@@ -468,8 +426,6 @@ def __init__(
         all_final_layer = {}
         for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
             x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            nn.init.xavier_uniform_(x_embedder.weight)
-            nn.init.constant_(x_embedder.bias, 0.0)
             all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
 
             final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
@@ -698,24 +654,23 @@ def forward(
         ]
         x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)
 
-        x_shard = torch.cat(x, dim=0)
-        x_src_ids_shard = torch.cat(x_src_ids, dim=0)
-        x_freqs_cis_shard = torch.cat(x_freqs_cis, dim=0)
-        x_pad_mask_shard = torch.cat(x_pad_mask, dim=0)
-        del x
+        x = torch.cat(x, dim=0)
+        x_src_ids = torch.cat(x_src_ids, dim=0)
+        x_freqs_cis = torch.cat(x_freqs_cis, dim=0)
+        x_pad_mask = torch.cat(x_pad_mask, dim=0)
 
-        x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
-        x_shard[x_pad_mask_shard] = self.x_pad_token
+        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x[x_pad_mask] = self.x_pad_token
         for layer in self.noise_refiner:
-            x_shard = layer(
-                x_shard,
-                x_src_ids_shard,
-                x_freqs_cis_shard,
+            x = layer(
+                x,
+                x_src_ids,
+                x_freqs_cis,
                 x_cu_seqlens,
                 x_max_item_seqlen,
                 adaln_input,
             )
-        x_flatten = x_shard
+        x_flatten = x
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
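The `x_cu_seqlens` / `x_max_item_seqlen` / `x_src_ids` bookkeeping for this packed layout is derived from the per-item sequence lengths. A hedged sketch of the usual derivation (not the exact code in this file):

```python
import torch
import torch.nn.functional as F


def pack_metadata(item_seqlens, device="cpu"):
    # cu_seqlens: prefix sums with a leading 0, as flash_attn_varlen_func expects;
    # max_item_seqlen bounds any single item; src_ids maps each token to its item.
    seqlens = torch.tensor(item_seqlens, dtype=torch.int32, device=device)
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_item_seqlen = int(seqlens.max())
    src_ids = torch.repeat_interleave(torch.arange(len(item_seqlens), device=device), seqlens)
    return cu_seqlens, max_item_seqlen, src_ids


cu, max_len, src_ids = pack_metadata([3, 2])
# cu == tensor([0, 3, 5], dtype=torch.int32); src_ids == tensor([0, 0, 0, 1, 1])
```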
@@ -734,23 +689,23 @@ def forward(
         ]
         cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)
 
-        cap_shard = torch.cat(cap_feats, dim=0)
-        cap_src_ids_shard = torch.cat(cap_src_ids, dim=0)
-        cap_freqs_cis_shard = torch.cat(cap_freqs_cis, dim=0)
-        cap_pad_mask_shard = torch.cat(cap_pad_mask, dim=0)
+        cap = torch.cat(cap_feats, dim=0)
+        cap_src_ids = torch.cat(cap_src_ids, dim=0)
+        cap_freqs_cis = torch.cat(cap_freqs_cis, dim=0)
+        cap_pad_mask = torch.cat(cap_pad_mask, dim=0)
         del cap_feats
 
-        cap_shard = self.cap_embedder(cap_shard)
-        cap_shard[cap_pad_mask_shard] = self.cap_pad_token
+        cap = self.cap_embedder(cap)
+        cap[cap_pad_mask] = self.cap_pad_token
         for layer in self.context_refiner:
-            cap_shard = layer(
-                cap_shard,
-                cap_src_ids_shard,
-                cap_freqs_cis_shard,
+            cap = layer(
+                cap,
+                cap_src_ids,
+                cap_freqs_cis,
                 cap_cu_seqlens,
                 cap_max_item_seqlen,
             )
-        cap_flatten = cap_shard
+        cap_flatten = cap
 
         # unified
         def merge_interleave(l1, l2):
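The `def merge_interleave(l1, l2):` line appears as context, but its body lies outside the hunks shown. From the call sites it interleaves per-item chunks as `[l1[0], l2[0], l1[1], l2[1], ...]`, so that each item's caption tokens directly precede its image tokens. A plausible stand-in (an assumption, not the committed implementation):

```python
import torch


def merge_interleave(l1, l2):
    # Assumed behavior: [l1[0], l2[0], l1[1], l2[1], ...]
    assert len(l1) == len(l2)
    return [chunk for pair in zip(l1, l2) for chunk in pair]


cap_src_ids = torch.tensor([0, 0, 1])  # caption tokens per item: [2, 1]
x_src_ids = torch.tensor([0, 1, 1])    # image tokens per item:   [1, 2]
merged = torch.cat(
    merge_interleave(cap_src_ids.split([2, 1], dim=0), x_src_ids.split([1, 2], dim=0))
)
# merged == tensor([0, 0, 0, 1, 1, 1]): item 0's caption+image tokens, then item 1's
```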
@@ -774,41 +729,32 @@ def merge_interleave(l1, l2):
             ),
             (1, 0),
         )
-        unified_src_ids = torch.cat(merge_interleave(cap_src_ids, x_src_ids))
-        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis))
-
-        unified_shard = unified
-        unified_src_ids_shard = unified_src_ids
-        unified_freqs_cis_shard = unified_freqs_cis
+        unified_src_ids = torch.cat(
+            merge_interleave(
+                cap_src_ids.split(cap_item_seqlens, dim=0),
+                x_src_ids.split(x_item_seqlens, dim=0),
+            )
+        )
+        unified_freqs_cis = torch.cat(
+            merge_interleave(
+                cap_freqs_cis.split(cap_item_seqlens, dim=0),
+                x_freqs_cis.split(x_item_seqlens, dim=0),
+            )
+        )
         for layer in self.layers:
-            unified_shard = layer(
-                unified_shard,
-                unified_src_ids_shard,
-                unified_freqs_cis_shard,
+            unified = layer(
+                unified,
+                unified_src_ids,
+                unified_freqs_cis,
                 unified_cu_seqlens,
                 unified_max_item_seqlen,
                 adaln_input,
             )
-        unified_shard = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
-            unified_shard, unified_src_ids_shard, adaln_input
-        )
-        unified = unified_shard.split(unified_item_seqlens, dim=0)
+        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, unified_src_ids, adaln_input)
+        unified = unified.split(unified_item_seqlens, dim=0)
         x = [unified[i][cap_item_seqlens[i] :] for i in range(bsz)]
         assert all(len(x[i]) == x_item_seqlens[i] for i in range(bsz))
 
         x = self.unpatchify(x, x_size, patch_size, f_patch_size)
 
         return x, {}
-
-    def parameter_count(self) -> int:
-        total_params = 0
-
-        def _recursive_count_params(module):
-            nonlocal total_params
-            for param in module.parameters(recurse=False):
-                total_params += param.numel()
-            for submodule in module.children():
-                _recursive_count_params(submodule)
-
-        _recursive_count_params(self)
-        return total_params
src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 0 additions & 29 deletions
@@ -249,35 +249,6 @@ def _encode_prompt(
 
         return embeddings_list
 
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        self.vae.disable_tiling()
-
     def prepare_latents(
         self,
         batch_size,
0 comments