 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import (
-    AllegroAttnProcessor2_0,
-    Attention,
-)
-from ..embeddings import PixArtAlphaTextProjection
+from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AllegroAdaLayerNormSingle
 logger = logging.get_logger(__name__)


-class PatchEmbed2D(nn.Module):
-    """2D Image to Patch Embedding"""
-
-    def __init__(
-        self,
-        num_frames=1,
-        height=224,
-        width=224,
-        patch_size_t=1,
-        patch_size=16,
-        in_channels=3,
-        embed_dim=768,
-        layer_norm=False,
-        flatten=True,
-        bias=True,
-        use_abs_pos=False,
-    ):
-        super().__init__()
-        self.use_abs_pos = use_abs_pos
-        self.flatten = flatten
-        self.layer_norm = layer_norm
-
-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias
-        )
-        if layer_norm:
-            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
-        else:
-            self.norm = None
-
-        self.patch_size_t = patch_size_t
-        self.patch_size = patch_size
-
-    def forward(self, latent):
-        b, _, _, _, _ = latent.shape
-        video_latent = None
-
-        latent = rearrange(latent, "b c t h w -> (b t) c h w")
-
-        latent = self.proj(latent)
-        if self.flatten:
-            latent = latent.flatten(2).transpose(1, 2)  # BT C H W -> BT N C
-        if self.layer_norm:
-            latent = self.norm(latent)
-
-        latent = rearrange(latent, "(b t) n c -> b (t n) c", b=b)
-        video_latent = latent
-
-        return video_latent
-
-
 @maybe_allow_in_graph
 class AllegroTransformerBlock(nn.Module):
     r"""
@@ -280,13 +225,13 @@ def __init__(
         interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40

         # 1. Patch embedding
-        self.pos_embed = PatchEmbed2D(
+        self.pos_embed = PatchEmbed(
             height=sample_height,
             width=sample_width,
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=self.inner_dim,
-            # pos_embed_type=None,
+            pos_embed_type=None,
         )

         # 2. Transformer blocks
@@ -327,8 +272,8 @@ def _set_gradient_checkpointing(self, module, value=False):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        timestep: Optional[torch.LongTensor] = None,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
@@ -368,13 +313,9 @@ def forward(
             )

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
-        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
-            # b, 1+use_image_num, l -> a video with images
-            # b, 1, l -> only images
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
-            encoder_attention_mask = (
-                rearrange(encoder_attention_mask, "b 1 l -> (b 1) 1 l") if encoder_attention_mask.numel() > 0 else None
-            )
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

         # 1. Input
         post_patch_num_frames = num_frames // self.config.patch_size_temporal
@@ -385,9 +326,9 @@ def forward(
             timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
         )

-        hidden_states = self.pos_embed(
-            hidden_states
-        )  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+        hidden_states = self.pos_embed(hidden_states)
+        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
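Note on the reshapes in the last hunk: the removed einops rearranges and the plain-torch permute/flatten/unflatten calls that replace them are intended to produce identical layouts, with the shared PatchEmbed (constructed with pos_embed_type=None, so it only does the convolutional projection and flattening) applied per frame in between. The snippet below is a minimal, illustrative sketch with made-up tensor sizes that checks the two shape paths agree; it is not part of the diff.

import torch

# Hypothetical sizes for illustration only (not taken from the model config).
b, c, t, h, w = 2, 4, 3, 8, 8
latent = torch.randn(b, c, t, h, w)

# Old path: rearrange(latent, "b c t h w -> (b t) c h w")
# New path: move the frame axis next to batch, then merge the two dims.
merged = latent.permute(0, 2, 1, 3, 4).flatten(0, 1)  # (b*t, c, h, w)
assert merged.shape == (b * t, c, h, w)

# After patch embedding each frame is a token sequence of shape (b*t, n, d).
n, d = 16, 32
tokens = torch.randn(b * t, n, d)

# Old path: rearrange(tokens, "(b t) n c -> b (t n) c", b=b)
# New path: split batch/frame back out, then fold frames into the token axis.
joined = tokens.unflatten(0, (b, -1)).flatten(1, 2)  # (b, t*n, d)
assert joined.shape == (b, t * n, d)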