import math
from typing import List, Optional, Tuple

- from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
+ from einops import rearrange
+

try:
    from flash_attn import flash_attn_varlen_func

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...utils.torch_utils import maybe_allow_in_graph
- from ...models.attention_processor import Attention
- from ...models.attention_dispatch import dispatch_attention_fn
+

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
@@ -88,10 +89,10 @@ def forward(self, t):

class ZSingleStreamAttnProcessor:
    """
-     Processor for Z-Image single stream attention that adapts the existing Attention class
-     to match the behavior of the original Z-ImageAttention module.
+     Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
+     original Z-ImageAttention module.
    """
-
+
    _attention_backend = None
    _parallel_config = None

@@ -107,24 +108,24 @@ def __call__(
    ) -> torch.Tensor:
        x_shard = hidden_states
        x_freqs_cis_shard = image_rotary_emb
-
+
        query = attn.to_q(x_shard)
        key = attn.to_k(x_shard)
        value = attn.to_v(x_shard)
-
+
        seqlen_shard = x_shard.shape[0]
-
+
        # Reshape to [seq_len, heads, head_dim]
        head_dim = query.shape[-1] // attn.heads
        query = query.view(seqlen_shard, attn.heads, head_dim)
        key = key.view(seqlen_shard, attn.heads, head_dim)
-         value = value.view(seqlen_shard, attn.heads, head_dim)
+         value = value.view(seqlen_shard, attn.heads, head_dim)
        # Apply Norms
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
-
+
        # Apply RoPE
        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            with torch.amp.autocast("cuda", enabled=False):
@@ -136,17 +137,17 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
        if x_freqs_cis_shard is not None:
            query = apply_rotary_emb(query, x_freqs_cis_shard)
            key = apply_rotary_emb(key, x_freqs_cis_shard)
-
+
        # Cast to correct dtype
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)
-
+
        # Flash Attention
        softmax_scale = math.sqrt(1 / head_dim)
        assert dtype in [torch.float16, torch.bfloat16]
-
+
        if x_cu_seqlens is None or x_max_item_seqlen is None:
-             raise ValueError("x_cu_seqlens and x_max_item_seqlen are required for ZSingleStreamAttnProcessor")
+             raise ValueError("x_cu_seqlens and x_max_item_seqlen are required for ZSingleStreamAttnProcessor")

        if flash_attn_varlen_func is not None:
            output = flash_attn_varlen_func(
@@ -164,45 +165,50 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
            output = output.flatten(-2)
        else:
            seqlens = (x_cu_seqlens[1:] - x_cu_seqlens[:-1]).cpu().tolist()
-
+
            q_split = torch.split(query, seqlens, dim=0)
            k_split = torch.split(key, seqlens, dim=0)
            v_split = torch.split(value, seqlens, dim=0)
-
+
            q_padded = torch.nn.utils.rnn.pad_sequence(q_split, batch_first=True)
            k_padded = torch.nn.utils.rnn.pad_sequence(k_split, batch_first=True)
            v_padded = torch.nn.utils.rnn.pad_sequence(v_split, batch_first=True)
-
+
            batch_size, max_seqlen, _, _ = q_padded.shape
-
+
            mask = torch.zeros((batch_size, max_seqlen), dtype=torch.bool, device=query.device)
            for i, l in enumerate(seqlens):
                mask[i, :l] = True
-
+
            attn_mask = torch.zeros((batch_size, 1, 1, max_seqlen), dtype=query.dtype, device=query.device)
            attn_mask.masked_fill_(~mask[:, None, None, :], torch.finfo(query.dtype).min)
-
+
            q_padded = q_padded.transpose(1, 2)
            k_padded = k_padded.transpose(1, 2)
            v_padded = v_padded.transpose(1, 2)
-
+
            output = F.scaled_dot_product_attention(
-                 q_padded, k_padded, v_padded, attn_mask=attn_mask, dropout_p=0.0, scale=softmax_scale
+                 q_padded,
+                 k_padded,
+                 v_padded,
+                 attn_mask=attn_mask,
+                 dropout_p=0.0,
+                 scale=softmax_scale,
            )
-
+
            output = output.transpose(1, 2)
-
+
            out_list = []
            for i, l in enumerate(seqlens):
                out_list.append(output[i, :l])
-
+
            output = torch.cat(out_list, dim=0)
            output = output.flatten(-2)

        output = attn.to_out[0](output)
-         if len(attn.to_out) > 1: # dropout
-             output = attn.to_out[1](output)
-
+         if len(attn.to_out) > 1:  # dropout
+             output = attn.to_out[1](output)
+
        return output

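
As an aside, the SDPA fallback introduced above can be exercised on its own. Below is a minimal sketch with made-up toy shapes (not the module's real tensors or API) of the same pattern: packed variable-length q/k/v are split per item, padded to a common length, masked, attended, and re-packed.

import torch
import torch.nn.functional as F

heads, head_dim = 2, 8
seqlens = [3, 5]                                        # tokens per item (toy values)
q = k = v = torch.randn(sum(seqlens), heads, head_dim)  # packed varlen layout

# Pad each item to the longest sequence in the batch.
q_pad = torch.nn.utils.rnn.pad_sequence(q.split(seqlens), batch_first=True)
k_pad = torch.nn.utils.rnn.pad_sequence(k.split(seqlens), batch_first=True)
v_pad = torch.nn.utils.rnn.pad_sequence(v.split(seqlens), batch_first=True)

# Additive mask that hides the padded key positions.
valid = torch.zeros(q_pad.shape[:2], dtype=torch.bool)
for i, n in enumerate(seqlens):
    valid[i, :n] = True
attn_mask = torch.zeros(len(seqlens), 1, 1, q_pad.shape[1])
attn_mask.masked_fill_(~valid[:, None, None, :], torch.finfo(q_pad.dtype).min)

out = F.scaled_dot_product_attention(
    q_pad.transpose(1, 2), k_pad.transpose(1, 2), v_pad.transpose(1, 2), attn_mask=attn_mask
).transpose(1, 2)
out = torch.cat([out[i, :n] for i, n in enumerate(seqlens)], dim=0)  # back to packed layout
print(out.shape)  # torch.Size([8, 2, 8])

Padding to the longest item keeps the fallback in stock PyTorch; flash_attn_varlen_func avoids that padding cost, which is why it remains the preferred path.
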
@@ -226,12 +232,19 @@ def forward(self, x):
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
    def __init__(
-         self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True
+         self,
+         layer_id: int,
+         dim: int,
+         n_heads: int,
+         n_kv_heads: int,
+         norm_eps: float,
+         qk_norm: bool,
+         modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
-
+
        # Refactored to use diffusers Attention with custom processor
        # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
        self.attention = Attention(
@@ -244,7 +257,7 @@ def __init__(
            bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
-
+
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

@@ -284,7 +297,12 @@ def forward(
            x_src_ids_shard = None

        x_shard = self.attn_forward(
-             x_shard, x_freqs_cis_shard, x_cu_seqlens, x_max_item_seqlen, scale_gate_msa, x_src_ids_shard
+             x_shard,
+             x_freqs_cis_shard,
+             x_cu_seqlens,
+             x_max_item_seqlen,
+             scale_gate_msa,
+             x_src_ids_shard,
        )

        x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
@@ -303,22 +321,22 @@ def attn_forward(
        if self.modulation:
            assert scale_gate is not None and x_src_ids_shard is not None
            scale_msa, gate_msa = scale_gate
-
+
            # Pass extra args needed for ZSingleStreamAttnProcessor
            attn_out = self.attention(
                self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
                image_rotary_emb=x_freqs_cis_shard,
                x_cu_seqlens=x_cu_seqlens,
-                 x_max_item_seqlen=x_max_item_seqlen
+                 x_max_item_seqlen=x_max_item_seqlen,
            )
-
+
            x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
        else:
            attn_out = self.attention(
                self.attention_norm1(x_shard),
                image_rotary_emb=x_freqs_cis_shard,
                x_cu_seqlens=x_cu_seqlens,
-                 x_max_item_seqlen=x_max_item_seqlen
+                 x_max_item_seqlen=x_max_item_seqlen,
            )
            x_shard = x_shard + self.attention_norm2(attn_out)
        return x_shard
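
For context on the scale_msa[x_src_ids_shard] / gate_msa[x_src_ids_shard] indexing above: tokens from several items are packed into one flat sequence, and x_src_ids_shard stores the item index of each token, so indexing the per-item modulation table broadcasts the right AdaLN scale and gate to every token. A tiny illustrative sketch (names and shapes here are assumptions, not the module's API):

import torch

num_items, dim = 2, 4
scale_msa = torch.randn(num_items, dim)    # one modulation vector per item
x_src_ids = torch.tensor([0, 0, 0, 1, 1])  # item id of each packed token
x = torch.randn(5, dim)                    # packed tokens from both items
modulated = x * scale_msa[x_src_ids]       # index-select expands to [5, dim]
print(modulated.shape)  # torch.Size([5, 4])
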
@@ -371,7 +389,10 @@ def forward(self, x_shard, x_src_ids_shard, c):

class RopeEmbedder:
    def __init__(
-         self, theta: float = 256.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (64, 128, 128)
+         self,
+         theta: float = 256.0,
+         axes_dims: List[int] = (16, 56, 56),
+         axes_lens: List[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
@@ -458,13 +479,29 @@ def __init__(
        self.all_final_layer = nn.ModuleDict(all_final_layer)
        self.noise_refiner = nn.ModuleList(
            [
-                 ZImageTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True)
+                 ZImageTransformerBlock(
+                     1000 + layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     norm_eps,
+                     qk_norm,
+                     modulation=True,
+                 )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
-                 ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False)
+                 ZImageTransformerBlock(
+                     layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     norm_eps,
+                     qk_norm,
+                     modulation=False,
+                 )
                for layer_id in range(n_refiner_layers)
            ]
        )
@@ -524,8 +561,6 @@ def patchify_and_embed(
        patch_size: int,
        f_patch_size: int,
    ):
-
-         bsz = len(all_image)
        pH = pW = patch_size
        pF = f_patch_size
        device = all_image[0].device
@@ -560,7 +595,10 @@ def patchify_and_embed(
                )
            )
            # padded feature
-             cap_padded_feat = torch.cat([all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)], dim=0)
+             cap_padded_feat = torch.cat(
+                 [all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)],
+                 dim=0,
+             )
            all_cap_feats_out.append(cap_padded_feat)

        ### Process Image
@@ -623,7 +661,6 @@ def forward(
        patch_size=2,
        f_patch_size=1,
    ):
-
        assert patch_size in self.all_patch_size
        assert f_patch_size in self.all_f_patch_size

@@ -649,7 +686,11 @@ def forward(
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)
        x_cu_seqlens = F.pad(
-             torch.cumsum(torch.tensor(x_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32),
+             torch.cumsum(
+                 torch.tensor(x_item_seqlens, dtype=torch.int32, device=device),
+                 dim=0,
+                 dtype=torch.int32,
+             ),
            (1, 0),
        )
        x_src_ids = [
@@ -666,15 +707,26 @@ def forward(
            x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
            x_shard[x_pad_mask_shard] = self.x_pad_token
        for layer in self.noise_refiner:
-             x_shard = layer(x_shard, x_src_ids_shard, x_freqs_cis_shard, x_cu_seqlens, x_max_item_seqlen, adaln_input)
+             x_shard = layer(
+                 x_shard,
+                 x_src_ids_shard,
+                 x_freqs_cis_shard,
+                 x_cu_seqlens,
+                 x_max_item_seqlen,
+                 adaln_input,
+             )
        x_flatten = x_shard

        # cap embed & refine
        cap_item_seqlens = [len(_) for _ in cap_feats]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
        cap_max_item_seqlen = max(cap_item_seqlens)
        cap_cu_seqlens = F.pad(
-             torch.cumsum(torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32),
+             torch.cumsum(
+                 torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device),
+                 dim=0,
+                 dtype=torch.int32,
+             ),
            (1, 0),
        )
        cap_src_ids = [
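
The x_cu_seqlens / cap_cu_seqlens tensors built above are the cumulative-offset vectors that flash_attn_varlen_func (and the SDPA fallback) use to locate each item inside the packed sequence. A small worked example with made-up lengths:

import torch
import torch.nn.functional as F

x_item_seqlens = [96, 64]  # toy per-item token counts (already multiples of SEQ_MULTI_OF)
x_cu_seqlens = F.pad(
    torch.cumsum(torch.tensor(x_item_seqlens, dtype=torch.int32), dim=0, dtype=torch.int32),
    (1, 0),
)
print(x_cu_seqlens)  # tensor([0, 96, 160], dtype=torch.int32)
# Item i occupies positions x_cu_seqlens[i] : x_cu_seqlens[i + 1] of the packed sequence.
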
@@ -705,14 +757,20 @@ def merge_interleave(l1, l2):
            return list(itertools.chain(*zip(l1, l2)))

        unified = torch.cat(
-             merge_interleave(cap_flatten.split(cap_item_seqlens, dim=0), x_flatten.split(x_item_seqlens, dim=0)), dim=0
+             merge_interleave(
+                 cap_flatten.split(cap_item_seqlens, dim=0),
+                 x_flatten.split(x_item_seqlens, dim=0),
+             ),
+             dim=0,
        )
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert len(unified) == sum(unified_item_seqlens)
        unified_max_item_seqlen = max(unified_item_seqlens)
        unified_cu_seqlens = F.pad(
            torch.cumsum(
-                 torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device), dim=0, dtype=torch.int32
+                 torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device),
+                 dim=0,
+                 dtype=torch.int32,
            ),
            (1, 0),
        )
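
merge_interleave above places each item's caption tokens directly ahead of its image tokens in one flat sequence, so the unified per-item lengths are simply the sums of the two. A toy sketch (illustrative values only, not the real feature tensors):

import itertools

import torch

def merge_interleave(l1, l2):
    return list(itertools.chain(*zip(l1, l2)))

cap_item_seqlens, x_item_seqlens = [2, 3], [4, 1]
cap_flatten = torch.arange(5).float().unsqueeze(-1)      # 2 + 3 caption tokens
x_flatten = torch.arange(5).float().unsqueeze(-1) + 100  # 4 + 1 image tokens

unified = torch.cat(
    merge_interleave(cap_flatten.split(cap_item_seqlens), x_flatten.split(x_item_seqlens)),
    dim=0,
)
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
print(unified.squeeze(-1))   # tensor([0., 1., 100., 101., 102., 103., 2., 3., 4., 104.])
print(unified_item_seqlens)  # [6, 4]
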