@@ -314,21 +314,20 @@ def __init__(
         interpolation_scale_h: float = 2.0,
         interpolation_scale_w: float = 2.0,
         interpolation_scale_t: float = 2.2,
-        use_additional_conditions: Optional[bool] = None,
         use_rotary_positional_embeddings: bool = True,
         model_max_length: int = 300,
     ):
         super().__init__()

         self.inner_dim = num_attention_heads * attention_head_dim
-        self.out_channels = in_channels if out_channels is None else out_channels

         interpolation_scale_t = (
             interpolation_scale_t if interpolation_scale_t is not None else ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16
         )
         interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
         interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40

+        # 1. Patch embedding
         self.pos_embed = PatchEmbed2D(
             height=sample_height,
             width=sample_width,
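
The default-scale fallbacks in the hunk above can be sanity-checked in isolation. A minimal sketch, with hypothetical sample dimensions (not values from the commit):

import math  # not required; shown standalone

def default_scales(sample_frames, sample_height, sample_width):
    # mirrors the fallback expressions in the hunk above
    scale_t = ((sample_frames - 1) // 16 + 1) if sample_frames % 2 == 1 else sample_frames // 16
    scale_h = sample_height / 30
    scale_w = sample_width / 40
    return scale_t, scale_h, scale_w

print(default_scales(22, 60, 80))   # (1, 2.0, 2.0)
print(default_scales(33, 90, 120))  # (3, 3.0, 3.0)
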
@@ -337,9 +336,8 @@ def __init__(
             embed_dim=self.inner_dim,
             # pos_embed_type=None,
         )
-        interpolation_scale_thw = (interpolation_scale_t, interpolation_scale_h, interpolation_scale_w)

-        # 3. Define transformers blocks, spatial attention
+        # 2. Transformer blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 AllegroTransformerBlock(
@@ -358,19 +356,18 @@ def __init__(
             ]
         )

-        # 4. Define output layers
+        # 3. Output projection & norm
         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
-        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)

-        # 5. PixArt-Alpha blocks.
+        # 4. Timestep embeddings
         self.adaln_single = AllegroAdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)

-        self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(
-                in_features=caption_channels, hidden_size=self.inner_dim
-            )
+        # 5. Caption projection
+        self.caption_projection = PixArtAlphaTextProjection(
+            in_features=caption_channels, hidden_size=self.inner_dim
+        )

         self.gradient_checkpointing = False

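
The scale_shift_table registered above is consumed by the rewritten output path in the final hunk below. A minimal broadcast sketch with hypothetical sizes, using a plain layer_norm as a stand-in for self.norm_out:

import torch
import torch.nn.functional as F

batch_size, num_tokens, dim = 2, 128, 64
scale_shift_table = torch.randn(2, dim) / dim**0.5  # initialized as above
embedded_timestep = torch.randn(batch_size, dim)    # assumed [B, D] output of adaln_single
hidden_states = torch.randn(batch_size, num_tokens, dim)

# [1, 2, D] + [B, 1, D] -> [B, 2, D] -> two [B, 1, D] chunks
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = F.layer_norm(hidden_states, (dim,)) * (1 + scale) + shift
print(hidden_states.shape)  # torch.Size([2, 128, 64])
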
@@ -382,15 +379,14 @@ def forward(
         hidden_states: torch.Tensor,
         timestep: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        added_cond_kwargs: Dict[str, torch.Tensor] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ):
-        batch_size, c, frame, h, w = hidden_states.shape
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t = self.config.patch_size_temporal
+        p = self.config.patch_size

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
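
The next hunk renames the mask variables and replaces rearrange with flatten/view. A standalone sketch of the mask-to-bias conversion (keep = 0.0, discard = -10000.0), with hypothetical shapes and patch sizes:

import torch
import torch.nn.functional as F

batch_size, num_frames, height, width = 1, 8, 16, 16
p_t, p = 2, 2  # assumed temporal/spatial patch sizes

mask = torch.ones(batch_size, num_frames, height, width)  # 1 = keep, 0 = discard
mask[:, :, 8:, :] = 0  # pretend the bottom half is padding

mask = mask.unsqueeze(1)  # [B, 1, T, H, W]
# a patch stays visible if any pixel inside it is visible
mask = F.max_pool3d(mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
bias = (1 - mask.flatten(1).view(batch_size, 1, -1).bool().float()) * -10000.0
print(bias.shape)  # torch.Size([1, 1, 256]) = (T // p_t) * (H // p) * (W // p) tokens
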
@@ -409,112 +405,61 @@ def forward(
             # (keep = +0, discard = -10000.0)
             # b, frame+use_image_num, h, w -> a video with images
             # b, 1, h, w -> only images
-            attention_mask = attention_mask.to(self.dtype)
-            attention_mask_vid = attention_mask[:, :frame]  # b, frame, h, w
+            attention_mask = attention_mask.to(hidden_states.dtype)
+            attention_mask = attention_mask[:, :num_frames]  # [batch_size, num_frames, height, width]

-            if attention_mask_vid.numel() > 0:
-                attention_mask_vid = attention_mask_vid.unsqueeze(1)  # b 1 t h w
-                attention_mask_vid = F.max_pool3d(attention_mask_vid, kernel_size=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size), stride=(self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size))
-                attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)')
+            if attention_mask.numel() > 0:
+                attention_mask = attention_mask.unsqueeze(1)  # [batch_size, 1, num_frames, height, width]
+                attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
+                attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)

-            attention_mask_vid = (1 - attention_mask_vid.bool().to(self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None
+            attention_mask = (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3:
             # b, 1+use_image_num, l -> a video with images
             # b, 1, l -> only images
             encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
-            encoder_attention_mask_vid = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None
+            encoder_attention_mask = rearrange(encoder_attention_mask, 'b 1 l -> (b 1) 1 l') if encoder_attention_mask.numel() > 0 else None

         # 1. Input
-        frame = frame // self.config.patch_size_temporal
-        height = hidden_states.shape[-2] // self.config.patch_size
-        width = hidden_states.shape[-1] // self.config.patch_size
+        post_patch_num_frames = num_frames // self.config.patch_size_temporal
+        post_patch_height = height // self.config.patch_size
+        post_patch_width = width // self.config.patch_size

-        added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if added_cond_kwargs is None else added_cond_kwargs
-        hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid = self._operate_on_patched_inputs(
-            hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size,
-        )
+        timestep, embedded_timestep = self.adaln_single(timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype)
+
+        hidden_states = self.pos_embed(hidden_states)  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])

-        for _, block in enumerate(self.transformer_blocks):
+        for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            block: AllegroTransformerBlock
-            hidden_states = block.forward(
+            hidden_states = block(
                 hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states_vid,
-                temb=timestep_vid,
-                attention_mask=attention_mask_vid,
-                encoder_attention_mask=encoder_attention_mask_vid,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=timestep,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
                 image_rotary_emb=image_rotary_emb,
             )

         # 3. Output
-        output = None
-        if hidden_states is not None:
-            output = self._get_output_for_patched_inputs(
-                hidden_states=hidden_states,
-                timestep=timestep_vid,
-                class_labels=class_labels,
-                embedded_timestep=embedded_timestep_vid,
-                num_frames=frame,
-                height=height,
-                width=width,
-            )  # b c t h w
+        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+        hidden_states = self.norm_out(hidden_states)
+
+        # Modulation
+        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1)
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

         if not return_dict:
             return (output,)

         return Transformer2DModelOutput(sample=output)
-
-    def _operate_on_patched_inputs(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, added_cond_kwargs: Dict[str, Any], batch_size: int):
-        hidden_states = self.pos_embed(hidden_states.to(self.dtype))  # TODO(aryan): remove dtype conversion here and move to pipeline if needed
-
-        timestep_vid = None
-        embedded_timestep_vid = None
-        encoder_hidden_states_vid = None
-
-        if self.adaln_single is not None:
-            if self.config.use_additional_conditions and added_cond_kwargs is None:
-                raise ValueError(
-                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                )
-            timestep, embedded_timestep = self.adaln_single(
-                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype
-            )  # b 6d, b d
-
-            timestep_vid = timestep
-            embedded_timestep_vid = embedded_timestep
-
-        if self.caption_projection is not None:
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # b, 1+use_image_num, l, d or b, 1, l, d
-            encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d')
-
-        return hidden_states, encoder_hidden_states_vid, timestep_vid, embedded_timestep_vid
-
-    def _get_output_for_patched_inputs(
-        self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None
-    ) -> torch.Tensor:
-        if self.config.norm_type != "ada_norm_single":
-            conditioning = self.transformer_blocks[0].norm1.emb(
-                timestep, class_labels, hidden_dtype=self.dtype
-            )
-            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
-            hidden_states = self.proj_out_2(hidden_states)
-        elif self.config.norm_type == "ada_norm_single":
-            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states)
-            # Modulation
-            hidden_states = hidden_states * (1 + scale) + shift
-            hidden_states = self.proj_out(hidden_states)
-            hidden_states = hidden_states.squeeze(1)
-
-        # unpatchify
-        if self.adaln_single is None:
-            height = width = int(hidden_states.shape[1] ** 0.5)
-        hidden_states = hidden_states.reshape(
-            shape=(-1, num_frames, height, width, self.config.patch_size_temporal, self.config.patch_size, self.config.patch_size, self.out_channels)
-        )
-        hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states)
-        output = hidden_states.reshape(-1, self.out_channels, num_frames * self.config.patch_size_temporal, height * self.config.patch_size, width * self.config.patch_size)
-        return output
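
The inlined unpatchify above can be checked shape-wise in isolation. A sketch with hypothetical dimensions (patch sizes and channel count are assumptions, not values from the commit):

import torch

batch_size, num_frames, height, width = 1, 8, 16, 16
p_t, p, out_channels = 2, 2, 4  # assumed patch sizes and channel count
pp_t, pp_h, pp_w = num_frames // p_t, height // p, width // p

# proj_out output: one token per patch, each holding a (p_t * p * p * C) pixel block
tokens = torch.randn(batch_size, pp_t * pp_h * pp_w, p_t * p * p * out_channels)

x = tokens.reshape(batch_size, pp_t, pp_h, pp_w, p_t, p, p, -1)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)  # -> [B, C, T//p_t, p_t, H//p, p, W//p, p]
video = x.reshape(batch_size, -1, num_frames, height, width)
print(video.shape)  # torch.Size([1, 4, 8, 16, 16])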