 except ImportError:
     flash_attn_varlen_func = None

+# todo see how other teams do this
 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except ImportError:
@@ -61,10 +62,6 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
                 bias=True,
             ),
         )
-        nn.init.normal_(self.mlp[0].weight, std=0.02)
-        nn.init.zeros_(self.mlp[0].bias)
-        nn.init.normal_(self.mlp[2].weight, std=0.02)
-        nn.init.zeros_(self.mlp[2].bias)

         self.frequency_embedding_size = frequency_embedding_size

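Note: this hunk drops the explicit truncated-normal initialisation of the timestep MLP; presumably the weights now come from a shared init routine or PyTorch's nn.Linear defaults (an assumption, the diff does not show where). If the old behaviour had to be reproduced externally, a minimal sketch would be:

    import torch.nn as nn

    def init_timestep_mlp(mlp: nn.Sequential) -> None:
        # hypothetical helper, not part of the module: re-applies the removed init
        nn.init.normal_(mlp[0].weight, std=0.02)  # first Linear
        nn.init.zeros_(mlp[0].bias)
        nn.init.normal_(mlp[2].weight, std=0.02)  # second Linear (index 2, after the activation)
        nn.init.zeros_(mlp[2].bias)
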
@@ -573,9 +570,9 @@ def patchify_and_embed(
         all_cap_pad_mask = []
         all_cap_feats_out = []

-        for i, image in enumerate(all_image):
-            ### LLM Text Encoder
-            cap_ori_len = len(all_cap_feats[i])
+        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
+            ### Process Caption
+            cap_ori_len = len(cap_feat)
             cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
             # padded position ids
             cap_padded_pos_ids = self.create_coordinate_grid(
@@ -596,7 +593,7 @@ def patchify_and_embed(
             )
             # padded feature
             cap_padded_feat = torch.cat(
-                [all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)],
+                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                 dim=0,
             )
             all_cap_feats_out.append(cap_padded_feat)
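Note: the caption path above rounds each caption up to a multiple of SEQ_MULTI_OF by repeating its last feature vector. A standalone sketch of the same pattern (illustrative names and a placeholder constant, not the module's API):

    import torch

    SEQ_MULTI_OF = 32  # placeholder; the real constant is defined elsewhere in this file

    def pad_caption(cap_feat: torch.Tensor, multiple: int = SEQ_MULTI_OF) -> torch.Tensor:
        # cap_feat: (seq_len, dim); round seq_len up to the next multiple of `multiple`
        pad_len = (-cap_feat.shape[0]) % multiple
        if pad_len == 0:
            return cap_feat
        # fill the padded positions by repeating the last token's features
        return torch.cat([cap_feat, cap_feat[-1:].repeat(pad_len, 1)], dim=0)

    print(pad_caption(torch.randn(45, 1024)).shape)  # torch.Size([64, 1024])
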
@@ -677,126 +674,123 @@ def forward(
             x_size,
             x_pos_ids,
             cap_pos_ids,
-            x_pad_mask,
-            cap_pad_mask,
+            x_inner_pad_mask,
+            cap_inner_pad_mask,
         ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

         # x embed & refine
         x_item_seqlens = [len(_) for _ in x]
         assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
         x_max_item_seqlen = max(x_item_seqlens)
-        x_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(x_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        x = torch.cat(x, dim=0)
+        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+        x = list(x.split(x_item_seqlens, dim=0))  # list so items can be padded in place below
+        x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)  # todo
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        x_src_ids = [
-            torch.full((count,), i, dtype=torch.int32, device=device) for i, count in enumerate(x_item_seqlens)
-        ]
-        x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)
-
-        x_shard = torch.cat(x, dim=0)
-        x_src_ids_shard = torch.cat(x_src_ids, dim=0)
-        x_freqs_cis_shard = torch.cat(x_freqs_cis, dim=0)
-        x_pad_mask_shard = torch.cat(x_pad_mask, dim=0)
-        del x
-
-        x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
-        x_shard[x_pad_mask_shard] = self.x_pad_token
+        x_pad_mask = torch.zeros(
+            (bsz, x_max_item_seqlen),
+            dtype=torch.bool,
+            device=device,
+        )
+        for i, item in enumerate(x):
+            seq_len = x_item_seqlens[i]
+            x[i] = torch.cat([item, pad_tensor.repeat(x_max_item_seqlen - seq_len, 1)])
+            x_pad_mask[i, seq_len:] = 1
+        x = torch.stack(x)
+
         for layer in self.noise_refiner:
-            x_shard = layer(
-                x_shard,
-                x_src_ids_shard,
-                x_freqs_cis_shard,
-                x_cu_seqlens,
-                x_max_item_seqlen,
+            x = layer(
+                x,
+                x_pad_mask,
+                x_freqs_cis,
                 adaln_input,
-            )
-        x_flatten = x_shard
+            )  # todo

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
         assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)
-        cap_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        cap_feats = torch.cat(cap_feats, dim=0)
+        cap_feats = self.cap_embedder(cap_feats)
+        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))  # list so items can be padded in place below
+        cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)  # todo
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        cap_src_ids = [
-            torch.full((count,), i, dtype=torch.int32, device=device) for i, count in enumerate(cap_item_seqlens)
-        ]
-        cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)
-
-        cap_shard = torch.cat(cap_feats, dim=0)
-        cap_src_ids_shard = torch.cat(cap_src_ids, dim=0)
-        cap_freqs_cis_shard = torch.cat(cap_freqs_cis, dim=0)
-        cap_pad_mask_shard = torch.cat(cap_pad_mask, dim=0)
-        del cap_feats
-
-        cap_shard = self.cap_embedder(cap_shard)
-        cap_shard[cap_pad_mask_shard] = self.cap_pad_token
+        cap_pad_mask = torch.zeros(
+            (bsz, cap_max_item_seqlen),
+            dtype=torch.bool,
+            device=device,
+        )
+        for i, item in enumerate(cap_feats):
+            seq_len = cap_item_seqlens[i]
+            cap_feats[i] = torch.cat([item, pad_tensor.repeat(cap_max_item_seqlen - seq_len, 1)])
+            cap_pad_mask[i, seq_len:] = 1
+        cap_feats = torch.stack(cap_feats)
         for layer in self.context_refiner:
-            cap_shard = layer(
-                cap_shard,
-                cap_src_ids_shard,
-                cap_freqs_cis_shard,
-                cap_cu_seqlens,
-                cap_max_item_seqlen,
+            cap_feats = layer(
+                cap_feats,
+                cap_pad_mask,
+                cap_freqs_cis,
             )
-        cap_flatten = cap_shard
-
-        # unified
-        def merge_interleave(l1, l2):
-            return list(itertools.chain(*zip(l1, l2)))

-        unified = torch.cat(
-            merge_interleave(
-                cap_flatten.split(cap_item_seqlens, dim=0),
-                x_flatten.split(x_item_seqlens, dim=0),
-            ),
-            dim=0,
-        )
+        # unified todo
         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
-        assert len(unified) == sum(unified_item_seqlens)
         unified_max_item_seqlen = max(unified_item_seqlens)
-        unified_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        unified_src_ids = torch.cat(merge_interleave(cap_src_ids, x_src_ids))
-        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis))
+        unified_pad_mask = torch.zeros(
+            (bsz, unified_max_item_seqlen),
+            dtype=torch.bool,
+            device=device,
+        )
+        unified = []
+        for i in range(bsz):
+            x_len = x_item_seqlens[i]
+            cap_len = cap_item_seqlens[i]
+            unified.append(
+                torch.cat(
+                    [
+                        x[i][:x_len],
+                        cap_feats[i][:cap_len],
+                        pad_tensor.repeat(unified_max_item_seqlen - x_len - cap_len, 1),
+                    ]
+                )
+            )
+            unified_pad_mask[i, x_len + cap_len:] = 1
+
+        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis))  # todo

-        unified_shard = unified
-        unified_src_ids_shard = unified_src_ids
-        unified_freqs_cis_shard = unified_freqs_cis
         for layer in self.layers:
-            unified_shard = layer(
-                unified_shard,
-                unified_src_ids_shard,
-                unified_freqs_cis_shard,
-                unified_cu_seqlens,
-                unified_max_item_seqlen,
+            unified = layer(
+                unified,
+                unified_pad_mask,
+                unified_freqs_cis,
                 adaln_input,
             )
-        unified_shard = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
-            unified_shard, unified_src_ids_shard, adaln_input
-        )
-        unified = unified_shard.split(unified_item_seqlens, dim=0)
-        x = [unified[i][cap_item_seqlens[i]:] for i in range(bsz)]
-        assert all(len(x[i]) == x_item_seqlens[i] for i in range(bsz))

-        x = self.unpatchify(x, x_size, patch_size, f_patch_size)
+        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
+            unified, adaln_input  # todo
+        )
+        unified = unified.split(unified_item_seqlens, dim=0)
+        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

         return x, {}

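Note: the rewritten forward() replaces the flash-attention varlen bookkeeping (cu_seqlens, per-token source ids, flattened shards) with dense batches: each per-item sequence is right-padded to the batch's longest sequence and a boolean mask marks the padded positions. A minimal self-contained sketch of that padding pattern (illustrative names, not the model's API):

    import torch

    def pad_to_dense(seqs: list[torch.Tensor], pad_value: float = 0.0):
        # seqs: list of (seq_len_i, dim) tensors
        # returns (bsz, max_len, dim) plus a bool mask where True marks padding
        lens = [s.shape[0] for s in seqs]
        max_len, dim = max(lens), seqs[0].shape[-1]
        out = seqs[0].new_full((len(seqs), max_len, dim), pad_value)
        pad_mask = torch.zeros(len(seqs), max_len, dtype=torch.bool, device=seqs[0].device)
        for i, (s, n) in enumerate(zip(seqs, lens)):
            out[i, :n] = s
            pad_mask[i, n:] = True
        return out, pad_mask

    x, mask = pad_to_dense([torch.randn(7, 16), torch.randn(5, 16)])
    print(x.shape, mask.sum(dim=1))  # torch.Size([2, 7, 16]) tensor([0, 2])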