 import torch.nn.functional as F
 from einops import rearrange
 
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.modeling_utils import ModelMixin
+from ...utils.import_utils import is_apex_available, is_flash_attn_available
+from ...utils.torch_utils import maybe_allow_in_graph
+
 
-try:
+if is_flash_attn_available():
     from flash_attn import flash_attn_varlen_func
-except ImportError:
+else:
     flash_attn_varlen_func = None
 
-try:
+if is_apex_available():
+    # Here needs apex with "APEX_CPP_EXT=1 APEX_CUDA_EXT=1 pip install -v --no-build-isolation ."
     from apex.normalization import FusedRMSNorm as RMSNorm
-except ImportError:
+else:
     from torch.nn import RMSNorm
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...models.attention_processor import Attention
-from ...models.modeling_utils import ModelMixin
-from ...utils.torch_utils import maybe_allow_in_graph
-
 
 ADALN_EMBED_DIM = 256
 SEQ_MULTI_OF = 32
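
As a side note on how these gated imports are typically exercised: the sketch below is illustrative only and not part of this diff; the item lengths, tensor shapes, and the CUDA-device assumption are placeholders. It packs variable-length items along one token dimension, builds the int32 cu_seqlens bookkeeping, and calls flash_attn_varlen_func when the import above succeeded, otherwise falls back to per-item SDPA.

import torch
import torch.nn.functional as F

# Hypothetical packed batch: three items of different lengths concatenated along
# the token dimension into [total_tokens, heads, head_dim] (assumes a CUDA device).
item_seqlens = [37, 64, 12]
heads, head_dim = 16, 64
q = torch.randn(sum(item_seqlens), heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Cumulative sequence lengths with a leading zero, int32, as flash-attn expects.
cu_seqlens = F.pad(torch.tensor(item_seqlens, device="cuda", dtype=torch.int32).cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = max(item_seqlens)

if flash_attn_varlen_func is not None:
    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
else:
    # Fallback: attend within each item separately via SDPA (equivalent, just slower).
    outs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        qi, ki, vi = (t[start:end].transpose(0, 1) for t in (q, k, v))  # [heads, seq, dim]
        outs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(0, 1))
    out = torch.cat(outs, dim=0)  # back to [total_tokens, heads, head_dim]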
@@ -61,10 +63,6 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
                 bias=True,
             ),
         )
-        nn.init.normal_(self.mlp[0].weight, std=0.02)
-        nn.init.zeros_(self.mlp[0].bias)
-        nn.init.normal_(self.mlp[2].weight, std=0.02)
-        nn.init.zeros_(self.mlp[2].bias)
 
         self.frequency_embedding_size = frequency_embedding_size
 
@@ -106,20 +104,20 @@ def __call__(
         x_cu_seqlens: Optional[torch.Tensor] = None,
         x_max_item_seqlen: Optional[int] = None,
     ) -> torch.Tensor:
-        x_shard = hidden_states
-        x_freqs_cis_shard = image_rotary_emb
+        x = hidden_states
+        x_freqs_cis = image_rotary_emb
 
-        query = attn.to_q(x_shard)
-        key = attn.to_k(x_shard)
-        value = attn.to_v(x_shard)
+        query = attn.to_q(x)
+        key = attn.to_k(x)
+        value = attn.to_v(x)
 
-        seqlen_shard = x_shard.shape[0]
+        seqlen = x.shape[0]
 
         # Reshape to [seq_len, heads, head_dim]
         head_dim = query.shape[-1] // attn.heads
-        query = query.view(seqlen_shard, attn.heads, head_dim)
-        key = key.view(seqlen_shard, attn.heads, head_dim)
-        value = value.view(seqlen_shard, attn.heads, head_dim)
+        query = query.view(seqlen, attn.heads, head_dim)
+        key = key.view(seqlen, attn.heads, head_dim)
+        value = value.view(seqlen, attn.heads, head_dim)
         # Apply Norms
         if attn.norm_q is not None:
             query = attn.norm_q(query)
@@ -134,9 +132,9 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
             x_out = torch.view_as_real(x * freqs_cis).flatten(2)
             return x_out.type_as(x_in)
 
-        if x_freqs_cis_shard is not None:
-            query = apply_rotary_emb(query, x_freqs_cis_shard)
-            key = apply_rotary_emb(key, x_freqs_cis_shard)
+        if x_freqs_cis is not None:
+            query = apply_rotary_emb(query, x_freqs_cis)
+            key = apply_rotary_emb(key, x_freqs_cis)
 
         # Cast to correct dtype
         dtype = query.dtype
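
For readers unfamiliar with the complex-number form of rotary embeddings used by apply_rotary_emb above, here is a self-contained sketch of the same idea; the names, toy shapes, and frequency construction are illustrative rather than taken from the diff. Channel pairs are viewed as complex numbers and rotated by unit-magnitude factors, so queries and keys pick up position-dependent phases.

import torch

def rotary_demo(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x_in: [seq, heads, head_dim]; freqs_cis: complex, broadcastable to [seq, 1, head_dim // 2].
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    x_out = torch.view_as_real(x * freqs_cis).flatten(2)
    return x_out.type_as(x_in)

seq, heads, head_dim = 8, 4, 64
x = torch.randn(seq, heads, head_dim)
angles = torch.outer(torch.arange(seq, dtype=torch.float32), torch.rand(head_dim // 2))
freqs_cis = torch.polar(torch.ones_like(angles), angles).unsqueeze(1)  # [seq, 1, head_dim // 2]
assert rotary_demo(x, freqs_cis).shape == x.shape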
@@ -277,9 +275,9 @@ def __init__(
 
     def forward(
         self,
-        x_shard: torch.Tensor,
-        x_src_ids_shard: torch.Tensor,
-        x_freqs_cis_shard: torch.Tensor,
+        x: torch.Tensor,
+        x_src_ids: torch.Tensor,
+        x_freqs_cis: torch.Tensor,
         x_cu_seqlens: torch.Tensor,
         x_max_item_seqlen: int,
         adaln_input: Optional[torch.Tensor] = None,
@@ -289,80 +287,40 @@ def forward(
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
             gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
             scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
-            scale_gate_msa = (scale_msa, gate_msa)
-            scale_gate_mlp = (scale_mlp, gate_mlp)
-        else:
-            scale_gate_msa = None
-            scale_gate_mlp = None
-            x_src_ids_shard = None
-
-        x_shard = self.attn_forward(
-            x_shard,
-            x_freqs_cis_shard,
-            x_cu_seqlens,
-            x_max_item_seqlen,
-            scale_gate_msa,
-            x_src_ids_shard,
-        )
 
-        x_shard = self.ffn_forward(x_shard, scale_gate_mlp, x_src_ids_shard)
-
-        return x_shard
-
-    def attn_forward(
-        self,
-        x_shard,
-        x_freqs_cis_shard,
-        x_cu_seqlens,
-        x_max_item_seqlen,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_msa, gate_msa = scale_gate
-
-            # Pass extra args needed for ZSingleStreamAttnProcessor
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard) * scale_msa[x_src_ids_shard],
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x) * scale_msa[x_src_ids],
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
+            x = x + gate_msa[x_src_ids] * self.attention_norm2(attn_out)
 
-            x_shard = x_shard + gate_msa[x_src_ids_shard] * self.attention_norm2(attn_out)
+            # FFN block
+            x = x + gate_mlp[x_src_ids] * self.ffn_norm2(
+                self.feed_forward(
+                    self.ffn_norm1(x) * scale_mlp[x_src_ids],
+                )
+            )
         else:
+            # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x_shard),
-                image_rotary_emb=x_freqs_cis_shard,
+                self.attention_norm1(x),
+                image_rotary_emb=x_freqs_cis,
                 x_cu_seqlens=x_cu_seqlens,
                 x_max_item_seqlen=x_max_item_seqlen,
             )
-            x_shard = x_shard + self.attention_norm2(attn_out)
-        return x_shard
+            x = x + self.attention_norm2(attn_out)
 
-    def ffn_forward(
-        self,
-        x_shard,
-        scale_gate: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        x_src_ids_shard: Optional[torch.Tensor] = None,
-    ):
-        if self.modulation:
-            assert scale_gate is not None and x_src_ids_shard is not None
-            scale_mlp, gate_mlp = scale_gate
-            x_shard = x_shard + gate_mlp[x_src_ids_shard] * self.ffn_norm2(
+            # FFN block
+            x = x + self.ffn_norm2(
                 self.feed_forward(
-                    self.ffn_norm1(x_shard) * scale_mlp[x_src_ids_shard],
+                    self.ffn_norm1(x),
                 )
             )
 
-        else:
-            x_shard = x_shard + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x_shard),
-                )
-            )
-        return x_shard
+        return x
 
 
 class FinalLayer(nn.Module):
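
One detail worth spelling out in the refactored block above: adaLN_modulation(adaln_input) yields one scale/gate row per batch item, while tokens from all items are packed into a single flat sequence, so indexing with x_src_ids (each token's item index) is what broadcasts the per-item modulation over the ragged token dimension. A tiny illustrative sketch with toy shapes, not taken from the diff:

import torch

item_seqlens = [3, 5]                                   # two items, 3 and 5 tokens each
scale = torch.randn(len(item_seqlens), 8)               # one modulation vector per item
src_ids = torch.repeat_interleave(torch.arange(len(item_seqlens)), torch.tensor(item_seqlens))
x = torch.randn(sum(item_seqlens), 8)                   # packed tokens from both items
x = x * scale[src_ids]                                  # each token scaled by its item's vector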
@@ -380,11 +338,11 @@ def __init__(self, hidden_size, out_channels):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
-    def forward(self, x_shard, x_src_ids_shard, c):
+    def forward(self, x, x_src_ids, c):
         scale = 1.0 + self.adaLN_modulation(c)
-        x_shard = self.norm_final(x_shard) * scale[x_src_ids_shard]
-        x_shard = self.linear(x_shard)
-        return x_shard
+        x = self.norm_final(x) * scale[x_src_ids]
+        x = self.linear(x)
+        return x
 
 
 class RopeEmbedder:
@@ -468,8 +426,6 @@ def __init__(
         all_final_layer = {}
         for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
             x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            nn.init.xavier_uniform_(x_embedder.weight)
-            nn.init.constant_(x_embedder.bias, 0.0)
             all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
 
             final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
@@ -698,24 +654,23 @@ def forward(
         ]
         x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)
 
-        x_shard = torch.cat(x, dim=0)
-        x_src_ids_shard = torch.cat(x_src_ids, dim=0)
-        x_freqs_cis_shard = torch.cat(x_freqs_cis, dim=0)
-        x_pad_mask_shard = torch.cat(x_pad_mask, dim=0)
-        del x
+        x = torch.cat(x, dim=0)
+        x_src_ids = torch.cat(x_src_ids, dim=0)
+        x_freqs_cis = torch.cat(x_freqs_cis, dim=0)
+        x_pad_mask = torch.cat(x_pad_mask, dim=0)
 
-        x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
-        x_shard[x_pad_mask_shard] = self.x_pad_token
+        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x[x_pad_mask] = self.x_pad_token
         for layer in self.noise_refiner:
-            x_shard = layer(
-                x_shard,
-                x_src_ids_shard,
-                x_freqs_cis_shard,
+            x = layer(
+                x,
+                x_src_ids,
+                x_freqs_cis,
                 x_cu_seqlens,
                 x_max_item_seqlen,
                 adaln_input,
             )
-        x_flatten = x_shard
+        x_flatten = x
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
@@ -734,23 +689,23 @@ def forward(
         ]
         cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)
 
-        cap_shard = torch.cat(cap_feats, dim=0)
-        cap_src_ids_shard = torch.cat(cap_src_ids, dim=0)
-        cap_freqs_cis_shard = torch.cat(cap_freqs_cis, dim=0)
-        cap_pad_mask_shard = torch.cat(cap_pad_mask, dim=0)
+        cap = torch.cat(cap_feats, dim=0)
+        cap_src_ids = torch.cat(cap_src_ids, dim=0)
+        cap_freqs_cis = torch.cat(cap_freqs_cis, dim=0)
+        cap_pad_mask = torch.cat(cap_pad_mask, dim=0)
         del cap_feats
 
-        cap_shard = self.cap_embedder(cap_shard)
-        cap_shard[cap_pad_mask_shard] = self.cap_pad_token
+        cap = self.cap_embedder(cap)
+        cap[cap_pad_mask] = self.cap_pad_token
         for layer in self.context_refiner:
-            cap_shard = layer(
-                cap_shard,
-                cap_src_ids_shard,
-                cap_freqs_cis_shard,
+            cap = layer(
+                cap,
+                cap_src_ids,
+                cap_freqs_cis,
                 cap_cu_seqlens,
                 cap_max_item_seqlen,
             )
-        cap_flatten = cap_shard
+        cap_flatten = cap
 
         # unified
         def merge_interleave(l1, l2):
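
The body of merge_interleave is outside this hunk; judging from its use below (the unified sequence is split per item and the caption prefix of length cap_item_seqlens[i] is stripped afterwards), it presumably interleaves the two per-item lists so each item's caption chunk is immediately followed by its image chunk. A minimal sketch under that assumption:

def merge_interleave(l1, l2):
    # Assumed behaviour: [l1[0], l2[0], l1[1], l2[1], ...] for equal-length sequences.
    out = []
    for a, b in zip(l1, l2):
        out.extend((a, b))
    return out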
@@ -774,41 +729,32 @@ def merge_interleave(l1, l2):
             ),
             (1, 0),
         )
-        unified_src_ids = torch.cat(merge_interleave(cap_src_ids, x_src_ids))
-        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis))
-
-        unified_shard = unified
-        unified_src_ids_shard = unified_src_ids
-        unified_freqs_cis_shard = unified_freqs_cis
+        unified_src_ids = torch.cat(
+            merge_interleave(
+                cap_src_ids.split(cap_item_seqlens, dim=0),
+                x_src_ids.split(x_item_seqlens, dim=0),
+            )
+        )
+        unified_freqs_cis = torch.cat(
+            merge_interleave(
+                cap_freqs_cis.split(cap_item_seqlens, dim=0),
+                x_freqs_cis.split(x_item_seqlens, dim=0),
+            )
+        )
         for layer in self.layers:
-            unified_shard = layer(
-                unified_shard,
-                unified_src_ids_shard,
-                unified_freqs_cis_shard,
+            unified = layer(
+                unified,
+                unified_src_ids,
+                unified_freqs_cis,
                 unified_cu_seqlens,
                 unified_max_item_seqlen,
                 adaln_input,
             )
-        unified_shard = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
-            unified_shard, unified_src_ids_shard, adaln_input
-        )
-        unified = unified_shard.split(unified_item_seqlens, dim=0)
+        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, unified_src_ids, adaln_input)
+        unified = unified.split(unified_item_seqlens, dim=0)
         x = [unified[i][cap_item_seqlens[i] :] for i in range(bsz)]
         assert all(len(x[i]) == x_item_seqlens[i] for i in range(bsz))
 
         x = self.unpatchify(x, x_size, patch_size, f_patch_size)
 
         return x, {}
-
-    def parameter_count(self) -> int:
-        total_params = 0
-
-        def _recursive_count_params(module):
-            nonlocal total_params
-            for param in module.parameters(recurse=False):
-                total_params += param.numel()
-            for submodule in module.children():
-                _recursive_count_params(submodule)
-
-        _recursive_count_params(self)
-        return total_params