from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AllegroAdaLayerNormSingle


logger = logging.get_logger(__name__)
@maybe_allow_in_graph
class AllegroTransformerBlock(nn.Module):
    r"""
-    Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
-
-    Args:
-        dim (`int`):
-            The number of channels in the input and output.
-        num_attention_heads (`int`):
-            The number of heads to use for multi-head attention.
-        attention_head_dim (`int`):
-            The number of channels in each head.
-        dropout (`float`, defaults to `0.0`):
-            The dropout probability to use.
-        cross_attention_dim (`int`, defaults to `2304`):
-            The dimension of the cross attention features.
-        activation_fn (`str`, defaults to `"gelu-approximate"`):
-            Activation function to be used in feed-forward.
-        attention_bias (`bool`, defaults to `False`):
-            Whether or not to use bias in attention projection layers.
-        only_cross_attention (`bool`, defaults to `False`):
-        norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether to use learnable elementwise affine parameters for normalization.
-        norm_eps (`float`, defaults to `1e-5`):
-            Epsilon value for normalization layers.
-        final_dropout (`bool` defaults to `False`):
-            Whether to apply a final dropout after the last feed-forward layer.
+    TODO(aryan): docs
    """

    def __init__(
@@ -71,8 +48,11 @@ def __init__(
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
    ):
        super().__init__()

@@ -85,7 +65,8 @@ def __init__(
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
            processor=AllegroAttnProcessor2_0(),
        )

@@ -98,6 +79,7 @@ def __init__(
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
+            upcast_attention=upcast_attention,
            processor=AllegroAttnProcessor2_0(),
        )  # is self-attn if encoder_hidden_states is none

@@ -108,6 +90,7 @@ def __init__(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
+            final_dropout=final_dropout,
        )

        # 4. Scale-shift
@@ -164,63 +147,49 @@ def forward(
        ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
+
+        # TODO(aryan): maybe following line is not required
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
        return hidden_states


class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

-    r"""
-    A 3D Transformer model for video-like data.
-
-    Args:
-        patch_size (`int`, defaults to `2`):
-            The size of spatial patches to use in the patch embedding layer.
-        patch_size_t (`int`, defaults to `1`):
-            The size of temporal patches to use in the patch embedding layer.
-        num_attention_heads (`int`, defaults to `24`):
-            The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, defaults to `96`):
-            The number of channels in each head.
-        in_channels (`int`, defaults to `4`):
-            The number of channels in the input.
-        out_channels (`int`, *optional*, defaults to `4`):
-            The number of channels in the output.
-        num_layers (`int`, defaults to `32`):
-            The number of layers of Transformer blocks to use.
-        dropout (`float`, defaults to `0.0`):
-            The dropout probability to use.
-        cross_attention_dim (`int`, defaults to `2304`):
-            The dimension of the cross attention features.
-        attention_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention projection layers.
-        sample_height (`int`, defaults to `90`):
-            The height of the input latents.
-        sample_width (`int`, defaults to `160`):
-            The width of the input latents.
-        sample_frames (`int`, defaults to `22`):
-            The number of frames in the input latents.
-        activation_fn (`str`, defaults to `"gelu-approximate"`):
-            Activation function to use in feed-forward.
-        norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether or not to use elementwise affine in normalization layers.
-        norm_eps (`float`, defaults to `1e-5`):
-            The epsilon value to use in normalization layers.
-        caption_channels (`int`, defaults to `4096`):
-            Number of channels to use for projecting the caption embeddings.
-        interpolation_scale_h (`float`, defaults to `2.0`):
-            Scaling factor to apply in 3D positional embeddings across height dimension.
-        interpolation_scale_w (`float`, defaults to `2.0`):
-            Scaling factor to apply in 3D positional embeddings across width dimension.
-        interpolation_scale_t (`float`, defaults to `2.2`):
-            Scaling factor to apply in 3D positional embeddings across time dimension.
+    """
+    A 2D Transformer model for image-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
+        num_vector_embeds (`int`, *optional*):
+            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+            Includes the class for the masked latent pixel.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*):
+            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+            added to the hidden states.
+
+            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlocks` attention should contain a bias parameter.
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
-        patch_size_t: int = 1,
+        patch_size_temporal: int = 1,
        num_attention_heads: int = 24,
        attention_head_dim: int = 96,
        in_channels: int = 4,
@@ -233,6 +202,7 @@ def __init__(
        sample_width: int = 160,
        sample_frames: int = 22,
        activation_fn: str = "gelu-approximate",
+        upcast_attention: bool = False,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
@@ -275,6 +245,7 @@ def __init__(
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
+                    upcast_attention=upcast_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
@@ -288,7 +259,7 @@ def __init__(
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

        # 4. Timestep embeddings
-        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+        self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

        # 5. Caption projection
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
@@ -309,13 +280,9 @@ def forward(
        return_dict: bool = True,
    ):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p_t = self.config.patch_size_t
+        p_t = self.config.patch_size_temporal
        p = self.config.patch_size

-        post_patch_num_frames = num_frames // self.config.patch_size_t
-        post_patch_height = height // self.config.patch_size
-        post_patch_width = width // self.config.patch_size
-
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -350,20 +317,22 @@ def forward(
            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

-        # 1. Timestep embeddings
+        # 1. Input
+        post_patch_num_frames = num_frames // self.config.patch_size_temporal
+        post_patch_height = height // self.config.patch_size
+        post_patch_width = width // self.config.patch_size
+
        timestep, embedded_timestep = self.adaln_single(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

-        # 2. Patch embeddings
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
        hidden_states = self.pos_embed(hidden_states)
        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

-        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            # TODO(aryan): Implement gradient checkpointing
            if self.gradient_checkpointing:
@@ -395,16 +364,16 @@ def custom_forward(*inputs):
                    image_rotary_emb=image_rotary_emb,
                )

-        # 4. Output normalization & projection
+        # 3. Output
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)

-        # modulation
+        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

-        # 5. Unpatchify
+        # unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
        )
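
For context on the unpatchify step at the end of this hunk: after `proj_out`, the flattened token sequence has length `(num_frames // p_t) * (height // p) * (width // p)`, and each token carries the pixel values of one `(p_t, p, p)` patch, so the reshape above recovers the patch grid before it is folded back into the latent video. Below is a minimal, self-contained sketch of that shape arithmetic; the tensor sizes and the final permute order are illustrative assumptions, not copied from this file:

```python
import torch

# Illustrative sizes only (not the model defaults): 2 frames, 8x8 latents,
# spatial patch size p = 2, temporal patch size p_t = 1, out_channels = 4.
batch_size, out_channels, num_frames, height, width = 1, 4, 2, 8, 8
p, p_t = 2, 1

post_patch_num_frames = num_frames // p_t  # 2
post_patch_height = height // p            # 4
post_patch_width = width // p              # 4
seq_len = post_patch_num_frames * post_patch_height * post_patch_width  # 32 tokens

# After proj_out, every token holds one (p_t, p, p) patch worth of channels.
hidden_states = torch.randn(batch_size, seq_len, p_t * p * p * out_channels)

# Same reshape as the end of this hunk: split tokens back into the patch grid.
hidden_states = hidden_states.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
)

# Hypothetical fold-back to (B, C, F, H, W); the exact permute in the file may differ.
video = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(
    batch_size, out_channels, num_frames, height, width
)
print(video.shape)  # torch.Size([1, 4, 2, 8, 8])
```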