 import torch.nn.functional as F
 from einops import rearrange

+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.modeling_utils import ModelMixin
+from ...utils.import_utils import is_apex_available, is_flash_attn_available
+from ...utils.torch_utils import maybe_allow_in_graph

-try:
+
+if is_flash_attn_available():
     from flash_attn import flash_attn_varlen_func
-except ImportError:
+else:
     flash_attn_varlen_func = None

-# todo see how other teams do this
-try:
+if is_apex_available():
+    # Requires apex built with extensions: APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation .
     from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
+else:
     from torch.nn import RMSNorm

-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...utils.torch_utils import maybe_allow_in_graph
-

 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
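
The is_flash_attn_available() / is_apex_available() guards above replace the earlier try/except imports. Such availability helpers are typically thin wrappers around an import probe; a minimal sketch, assuming a find_spec-based check (not necessarily how diffusers implements them):

import importlib.util

def _is_available(package: str) -> bool:
    # True if the package can be imported in the current environment.
    return importlib.util.find_spec(package) is not None

# Optional fused kernels can then be guarded exactly like the block above.
HAS_FLASH_ATTN = _is_available("flash_attn")
HAS_APEX = _is_available("apex")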
@@ -103,20 +104,20 @@ def __call__(
         x_cu_seqlens: Optional[torch.Tensor] = None,
         x_max_item_seqlen: Optional[int] = None,
     ) -> torch.Tensor:
-        x_shard = hidden_states
-        x_freqs_cis_shard = image_rotary_emb
+        x = hidden_states
+        x_freqs_cis = image_rotary_emb

-        query = attn.to_q(x_shard)
-        key = attn.to_k(x_shard)
-        value = attn.to_v(x_shard)
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)

-        seqlen_shard = x_shard.shape[0]
+        seqlen = x.shape[0]

         # Reshape to [seq_len, heads, head_dim]
         head_dim = query.shape[-1] // attn.heads
-        query = query.view(seqlen_shard, attn.heads, head_dim)
-        key = key.view(seqlen_shard, attn.heads, head_dim)
-        value = value.view(seqlen_shard, attn.heads, head_dim)
+        query = query.view(seqlen, attn.heads, head_dim)
+        key = key.view(seqlen, attn.heads, head_dim)
+        value = value.view(seqlen, attn.heads, head_dim)
         # Apply Norms
         if attn.norm_q is not None:
             query = attn.norm_q(query)
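
The x_cu_seqlens / x_max_item_seqlen arguments follow the packed variable-length layout expected by flash_attn_varlen_func: all items are concatenated along the sequence dimension and cu_seqlens stores the cumulative offsets. A small sketch of how such offsets are usually built (names and lengths are illustrative, not taken from this file):

import torch
import torch.nn.functional as F

# Three items of length 5, 3 and 8 packed into one (16, heads, head_dim) tensor.
item_lens = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(item_lens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 5, 8, 16])
max_item_seqlen = int(item_lens.max())  # 8
# flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_item_seqlen, max_item_seqlen, ...)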
@@ -131,9 +132,9 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
             x_out = torch.view_as_real(x * freqs_cis).flatten(2)
             return x_out.type_as(x_in)

-        if x_freqs_cis_shard is not None:
-            query = apply_rotary_emb(query, x_freqs_cis_shard)
-            key = apply_rotary_emb(key, x_freqs_cis_shard)
+        if x_freqs_cis is not None:
+            query = apply_rotary_emb(query, x_freqs_cis)
+            key = apply_rotary_emb(key, x_freqs_cis)

         # Cast to correct dtype
         dtype = query.dtype
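
Only two lines of the nested apply_rotary_emb helper are visible in this hunk; they suggest the usual complex-multiplication form of RoPE. A self-contained sketch under that assumption (the view_as_complex step, the shapes, and the theta base are illustrative, not taken from this file):

import torch

def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pair up the last dimension as complex numbers, rotate, then flatten back.
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x * freqs_cis).flatten(2)
    return x_out.type_as(x_in)

seq, heads, head_dim = 16, 8, 64
q = torch.randn(seq, heads, head_dim)
freqs = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq).float(), freqs)             # (seq, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)[:, None]  # (seq, 1, head_dim // 2), complex
q_rot = apply_rotary_emb(q, freqs_cis)                             # (seq, heads, head_dim)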
@@ -274,9 +275,9 @@ def __init__(

     def forward(
         self,
-        x_shard: torch.Tensor,
-        x_src_ids_shard: torch.Tensor,
-        x_freqs_cis_shard: torch.Tensor,
+        x: torch.Tensor,
+        x_src_ids: torch.Tensor,
+        x_freqs_cis: torch.Tensor,
         x_cu_seqlens: torch.Tensor,
         x_max_item_seqlen: int,
         adaln_input: Optional[torch.Tensor] = None,
@@ -286,80 +287,40 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
             gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
             scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
-            scale_gate_msa = (scale_msa, gate_msa)
-            scale_gate_mlp = (scale_mlp, gate_mlp)
-        else:
-            scale_gate_msa = None
-            scale_gate_mlp = None
-            x_src_ids_shard = None
-
-        x_shard = self.attn_forward(
-            x_shard,
-            x_freqs_cis_shard,
-            x_cu_seqlens,
-            x_max_item_seqlen,
-            scale_gate_msa,
-            x_src_ids_shard,
-        )

-        x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
-
-        return x_shard
-
-    def attn_forward(
-        self,
-        x_shard,
-        x_freqs_cis_shard,
-        x_cu_seqlens,
-        x_max_item_seqlen,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_msa, gate_msa = scale_gate
-
-            # Pass extra args needed for ZSingleStreamAttnProcessor
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x) * scale_msa[x_src_ids],
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
+            x = x + gate_msa[x_src_ids] * self.attention_norm2(attn_out)

-            x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
+            # FFN block
+            x = x + gate_mlp[x_src_ids] * self.ffn_norm2(
+                self.feed_forward(
+                    self.ffn_norm1(x) * scale_mlp[x_src_ids],
+                )
+            )
         else:
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard),
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x),
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
-            x_shard = x_shard + self.attention_norm2(attn_out)
-        return x_shard
+            x = x + self.attention_norm2(attn_out)

-    def ffn_forward(
-        self,
-        x_shard,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_mlp, gate_mlp = scale_gate
-            x_shard = x_shard + gate_mlp[x_src_ids_shard] * self.ffn_norm2(
+            # FFN block
+            x = x + self.ffn_norm2(
                 self.feed_forward(
-                    self.ffn_norm1(x_shard) * scale_mlp[x_src_ids_shard],
+                    self.ffn_norm1(x),
                 )
             )

-        else:
-            x_shard = x_shard + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x_shard),
-                )
-            )
-        return x_shard
+        return x


 class FinalLayer(nn.Module):
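
The merged forward above is easier to follow once the modulation tensors are pictured: one projection of the conditioning yields four per-sample vectors, the gates are squashed with tanh, the scales are offset by 1.0, and every packed token selects its sample's parameters through x_src_ids. A minimal sketch under those assumptions (the SiLU+Linear layout of adaLN_modulation and all sizes are illustrative):

import torch
import torch.nn as nn

dim, num_samples, num_tokens = 64, 2, 10
# Assumed layout: ADALN_EMBED_DIM -> 4 * dim, matching the chunk(4, dim=1) above.
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(256, 4 * dim))

adaln_input = torch.randn(num_samples, 256)
scale_msa, gate_msa, scale_mlp, gate_mlp = adaLN_modulation(adaln_input).chunk(4, dim=1)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

# Packed tokens: x_src_ids maps each token to the sample it came from.
x = torch.randn(num_tokens, dim)
x_src_ids = torch.tensor([0] * 4 + [1] * 6)
attn_in = x * scale_msa[x_src_ids]       # per-token scale, shape (num_tokens, dim)
x = x + gate_msa[x_src_ids] * attn_in    # per-token gated residual (attention output omitted)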
@@ -377,11 +338,11 @@ def __init__(self, hidden_size, out_channels):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)

-    def forward(self, x_shard, x_src_ids_shard, c):
+    def forward(self, x, x_src_ids, c):
         scale = 1.0 + self.adaLN_modulation(c)
-        x_shard = self.norm_final(x_shard) * scale[x_src_ids_shard]
-        x_shard = self.linear(x_shard)
-        return x_shard
+        x = self.norm_final(x) * scale[x_src_ids]
+        x = self.linear(x)
+        return x


 class RopeEmbedder:
@@ -465,8 +426,6 @@ def __init__(
         all_final_layer = {}
         for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
             x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            nn.init.xavier_uniform_(x_embedder.weight)
-            nn.init.constant_(x_embedder.bias, 0.0)
             all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

             final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
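
The x_embedder width above is simply the number of values in one spatio-temporal patch (f_patch_size * patch_size * patch_size * in_channels). A sketch of the matching patchify step with einops, assuming a (C, F, H, W) latent layout and illustrative sizes:

import torch
from einops import rearrange
from torch import nn

in_channels, dim = 16, 1024                     # illustrative sizes
patch_size, f_patch_size = 2, 1
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)

latent = torch.randn(in_channels, 4, 32, 32)    # (C, F, H, W), assumed layout
tokens = rearrange(
    latent,
    "c (f pf) (h ph) (w pw) -> (f h w) (pf ph pw c)",
    pf=f_patch_size, ph=patch_size, pw=patch_size,
)                                               # (1024 patches, 64 values each)
embedded = x_embedder(tokens)                   # (1024, dim)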
@@ -793,16 +752,3 @@ def forward(
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

         return x, {}
-
-    def parameter_count(self) -> int:
-        total_params = 0
-
-        def _recursive_count_params(module):
-            nonlocal total_params
-            for param in module.parameters(recurse=False):
-                total_params += param.numel()
-            for submodule in module.children():
-                _recursive_count_params(submodule)
-
-        _recursive_count_params(self)
-        return total_params
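
The removed parameter_count helper duplicated functionality that torch already provides; for a module without shared parameters the same number comes from a one-liner (model here stands for any instantiated nn.Module):

total_params = sum(p.numel() for p in model.parameters())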