@@ -227,13 +227,17 @@ def forward(
227227 # Prepare text embeddings for spatial block
228228 # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
229229 encoder_hidden_states = self .caption_projection (encoder_hidden_states ) # 3 120 1152
230- encoder_hidden_states_spatial = encoder_hidden_states .repeat_interleave (num_frame , dim = 0 ). view (
231- - 1 , encoder_hidden_states . shape [ - 2 ], encoder_hidden_states .shape [- 1 ]
232- )
230+ encoder_hidden_states_spatial = encoder_hidden_states .repeat_interleave (
231+ num_frame , dim = 0 , output_size = encoder_hidden_states .shape [0 ] * num_frame
232+ ). view ( - 1 , encoder_hidden_states . shape [ - 2 ], encoder_hidden_states . shape [ - 1 ])
233233
234234 # Prepare timesteps for spatial and temporal block
235- timestep_spatial = timestep .repeat_interleave (num_frame , dim = 0 ).view (- 1 , timestep .shape [- 1 ])
236- timestep_temp = timestep .repeat_interleave (num_patches , dim = 0 ).view (- 1 , timestep .shape [- 1 ])
235+ timestep_spatial = timestep .repeat_interleave (
236+ num_frame , dim = 0 , output_size = timestep .shape [0 ] * num_frame
237+ ).view (- 1 , timestep .shape [- 1 ])
238+ timestep_temp = timestep .repeat_interleave (
239+ num_patches , dim = 0 , output_size = timestep .shape [0 ] * num_patches
240+ ).view (- 1 , timestep .shape [- 1 ])
237241
238242 # Spatial and temporal transformer blocks
239243 for i , (spatial_block , temp_block ) in enumerate (
@@ -299,7 +303,9 @@ def forward(
299303 ).permute (0 , 2 , 1 , 3 )
300304 hidden_states = hidden_states .reshape (- 1 , hidden_states .shape [- 2 ], hidden_states .shape [- 1 ])
301305
302- embedded_timestep = embedded_timestep .repeat_interleave (num_frame , dim = 0 ).view (- 1 , embedded_timestep .shape [- 1 ])
306+ embedded_timestep = embedded_timestep .repeat_interleave (
307+ num_frame , dim = 0 , output_size = embedded_timestep .shape [0 ] * num_frame
308+ ).view (- 1 , embedded_timestep .shape [- 1 ])
303309 shift , scale = (self .scale_shift_table [None ] + embedded_timestep [:, None ]).chunk (2 , dim = 1 )
304310 hidden_states = self .norm_out (hidden_states )
305311 # Modulation
0 commit comments