@@ -116,17 +116,14 @@ def __init__(
         self.theta = theta
 
     def forward(
-        self, hidden_states: torch.Tensor, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None
+        self, hidden_states: torch.Tensor, num_frames: int, height: int, width: int, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        post_patch_num_frames = num_frames // self.patch_size_t
-        post_patch_height = height // self.patch_size
-        post_patch_width = width // self.patch_size
+        batch_size = hidden_states.size(0)
 
         # Always compute rope in fp32
-        grid_h = torch.arange(post_patch_height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(post_patch_width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(post_patch_num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
         grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -374,28 +371,20 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
         rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
-        image_rotary_emb = self.rope(hidden_states, rope_interpolation_scale)
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p = self.config.patch_size
-        p_t = self.config.patch_size_t
-
-        post_patch_height = height // p
-        post_patch_width = width // p
-        post_patch_num_frames = num_frames // p_t
-
-        hidden_states = hidden_states.reshape(
-            batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
-        )
-        hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+        batch_size = hidden_states.size(0)
         hidden_states = self.proj_in(hidden_states)
 
         temb, embedded_timestep = self.time_embed(
@@ -446,12 +435,7 @@ def custom_forward(*inputs):
 
         hidden_states = self.norm_out(hidden_states)
         hidden_states = hidden_states * (1 + scale) + shift
-        hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states.reshape(
-            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
-        )
-        output = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        output = self.proj_out(hidden_states)
 
         if not return_dict:
             return (output,)
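
Since forward no longer patchifies the input latent or unpatchifies its output, that reshaping now has to happen on the caller's side, and the post-patch num_frames, height, and width are passed in explicitly. A minimal sketch of that step, reconstructed from the removed lines above (the helper names and standalone-function form are assumptions, not the actual pipeline API):

import torch

def pack_latents(latents: torch.Tensor, patch_size: int, patch_size_t: int) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    latents = latents.reshape(
        batch_size,
        num_channels,
        num_frames // patch_size_t,
        patch_size_t,
        height // patch_size,
        patch_size,
        width // patch_size,
        patch_size,
    )
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

def unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int, patch_size_t: int
) -> torch.Tensor:
    # Inverse of pack_latents: [B, F' * H' * W', C * p_t * p * p] -> [B, C, F' * p_t, H' * p, W' * p],
    # where F', H', W' are the post-patch sizes that were passed to the transformer.
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)

Under this assumption, the caller packs the [B, C, F, H, W] latent before invoking the transformer and unpacks the returned sequence with the same post-patch sizes afterwards.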