from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AllegroAdaLayerNormSingle
+from ..normalization import AdaLayerNormSingle


logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class AllegroTransformerBlock(nn.Module):
    r"""
-    TODO(aryan): docs
+    Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to be used in feed-forward.
+        attention_bias (`bool`, defaults to `False`):
+            Whether or not to use bias in attention projection layers.
+        only_cross_attention (`bool`, defaults to `False`):
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, defaults to `1e-5`):
+            Epsilon value for normalization layers.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
    """

    def __init__(
@@ -48,11 +70,8 @@ def __init__(
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
-        only_cross_attention: bool = False,
-        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
-        final_dropout: bool = False,
    ):
        super().__init__()

@@ -65,8 +84,7 @@ def __init__(
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
+            cross_attention_dim=None,
            processor=AllegroAttnProcessor2_0(),
        )

@@ -79,9 +97,8 @@ def __init__(
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
-            upcast_attention=upcast_attention,
            processor=AllegroAttnProcessor2_0(),
-        )  # is self-attn if encoder_hidden_states is none
+        )

        # 3. Feed Forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
@@ -90,7 +107,6 @@ def __init__(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
-            final_dropout=final_dropout,
        )

        # 4. Scale-shift
@@ -159,37 +175,55 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    """
-    A 2D Transformer model for image-like data.
-
-    Parameters:
-        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
-        in_channels (`int`, *optional*):
-            The number of channels in the input and output (specify if the input is **continuous**).
-        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
-            This is fixed during training since it is used to learn a number of position embeddings.
-        num_vector_embeds (`int`, *optional*):
-            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
-            Includes the class for the masked latent pixel.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
-        num_embeds_ada_norm ( `int`, *optional*):
-            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
-            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
-            added to the hidden states.
-
-            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
-        attention_bias (`bool`, *optional*):
-            Configure if the `TransformerBlocks` attention should contain a bias parameter.
+    A 3D Transformer model for video-like data.
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of temporal patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `96`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `4`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `4`):
+            The number of channels in the output.
+        num_layers (`int`, defaults to `32`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        cross_attention_dim (`int`, defaults to `2304`):
+            The dimension of the cross attention features.
+        attention_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention projection layers.
+        sample_height (`int`, defaults to `90`):
+            The height of the input latents.
+        sample_width (`int`, defaults to `160`):
+            The width of the input latents.
+        sample_frames (`int`, defaults to `22`):
+            The number of frames in the input latents.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether or not to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        caption_channels (`int`, defaults to `4096`):
+            Number of channels to use for projecting the caption embeddings.
+        interpolation_scale_h (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across height dimension.
+        interpolation_scale_w (`float`, defaults to `2.0`):
+            Scaling factor to apply in 3D positional embeddings across width dimension.
+        interpolation_scale_t (`float`, defaults to `2.2`):
+            Scaling factor to apply in 3D positional embeddings across time dimension.
    """

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
-        patch_size_temporal: int = 1,
+        patch_size_t: int = 1,
        num_attention_heads: int = 24,
        attention_head_dim: int = 96,
        in_channels: int = 4,
@@ -202,7 +236,6 @@ def __init__(
        sample_width: int = 160,
        sample_frames: int = 22,
        activation_fn: str = "gelu-approximate",
-        upcast_attention: bool = False,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
@@ -245,7 +278,6 @@ def __init__(
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
-                    upcast_attention=upcast_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
@@ -259,7 +291,7 @@ def __init__(
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

        # 4. Timestep embeddings
-        self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
+        self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

        # 5. Caption projection
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
@@ -280,9 +312,13 @@ def forward(
        return_dict: bool = True,
    ):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p_t = self.config.patch_size_temporal
+        p_t = self.config.patch_size_t
        p = self.config.patch_size

+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -317,22 +353,20 @@ def forward(
            encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

-        # 1. Input
-        post_patch_num_frames = num_frames // self.config.patch_size_temporal
-        post_patch_height = height // self.config.patch_size
-        post_patch_width = width // self.config.patch_size
-
+        # 1. Timestep embeddings
        timestep, embedded_timestep = self.adaln_single(
            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

+        # 2. Patch embeddings
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
        hidden_states = self.pos_embed(hidden_states)
        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

+        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            # TODO(aryan): Implement gradient checkpointing
            if self.gradient_checkpointing:
@@ -364,7 +398,7 @@ def custom_forward(*inputs):
                    image_rotary_emb=image_rotary_emb,
                )

-        # 3. Output
+        # 4. Output normalization & projection
        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)

@@ -373,7 +407,7 @@ def custom_forward(*inputs):
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

-        # unpatchify
+        # 5. Unpatchify
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
        )
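
The patchify/unpatchify bookkeeping introduced above can be sanity-checked with a short, hypothetical shape sketch (not part of this commit). It only uses the defaults documented in the `AllegroTransformer3DModel` docstring; the `inner_dim` relation and the `post_patch_*` divisions mirror the code in the diff.

# Hypothetical shape sketch using the documented defaults (illustration only, not in the commit)
patch_size, patch_size_t = 2, 1
num_attention_heads, attention_head_dim = 24, 96
sample_frames, sample_height, sample_width = 22, 90, 160

inner_dim = num_attention_heads * attention_head_dim    # 2304, matches the cross_attention_dim default
post_patch_num_frames = sample_frames // patch_size_t   # 22
post_patch_height = sample_height // patch_size         # 45
post_patch_width = sample_width // patch_size           # 80

# Sequence length each AllegroTransformerBlock sees after patch embedding
num_tokens = post_patch_num_frames * post_patch_height * post_patch_width
print(num_tokens, inner_dim)  # 79200 2304

# proj_out emits patch_size * patch_size * out_channels features per token, and the
# final unpatchify reshape folds those back into video-shaped latents.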