@@ -979,6 +979,11 @@ def __call__(
979979 last_latent_tile_num_frames = last_latent_chunk .shape [2 ]
980980 latent_chunk = torch .cat ([last_latent_chunk , latent_chunk ], dim = 2 )
981981 total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
982+
983+ conditioning_mask = torch .zeros (
984+ (batch_size , total_latent_num_frames ), dtype = torch .float32 , device = device ,
985+ )
986+ conditioning_mask [:, :last_latent_tile_num_frames ] = 1.0
982987 else :
983988 total_latent_num_frames = latent_tile_num_frames
984989
@@ -998,25 +1003,28 @@ def __call__(
9981003 device = device ,
9991004 )
10001005
1006+ if start_index > 0 :
1007+ conditioning_mask = conditioning_mask .gather (1 , video_ids [:, 0 ])
1008+
1009+ video_ids = self ._scale_video_ids (
1010+ video_ids ,
1011+ scale_factor = self .vae_spatial_compression_ratio ,
1012+ scale_factor_t = self .vae_temporal_compression_ratio ,
1013+ frame_index = 0 ,
1014+ device = device
1015+ )
1016+ video_ids = video_ids .float ()
1017+ video_ids [:, 0 ] = video_ids [:, 0 ] * (1.0 / frame_rate )
1018+ if self .do_classifier_free_guidance :
1019+ video_ids = torch .cat ([video_ids , video_ids ], dim = 0 )
1020+
10011021 # Set timesteps
10021022 inner_timesteps , inner_num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , timesteps )
10031023 sigmas = self .scheduler .sigmas
10041024 num_warmup_steps = max (len (inner_timesteps ) - inner_num_inference_steps * self .scheduler .order , 0 )
10051025 self ._num_timesteps = len (inner_timesteps )
10061026
10071027 if start_index == 0 :
1008- video_ids = self ._scale_video_ids (
1009- video_ids ,
1010- scale_factor = self .vae_spatial_compression_ratio ,
1011- scale_factor_t = self .vae_temporal_compression_ratio ,
1012- frame_index = 0 ,
1013- device = device
1014- )
1015- video_ids = video_ids .float ()
1016- video_ids [:, 0 ] = video_ids [:, 0 ] * (1.0 / frame_rate )
1017- if self .do_classifier_free_guidance :
1018- video_ids = torch .cat ([video_ids , video_ids ], dim = 0 )
1019-
10201028 with self .progress_bar (total = inner_num_inference_steps ) as progress_bar :
10211029 for i , t in enumerate (inner_timesteps ):
10221030 if self .interrupt :
@@ -1079,24 +1087,6 @@ def __call__(
10791087 )
10801088 first_tile_out_latents = tile_out_latents .clone ()
10811089 else :
1082- conditioning_mask = torch .zeros (
1083- (batch_size , total_latent_num_frames ), dtype = torch .float32 , device = device ,
1084- )
1085- conditioning_mask [:, :last_latent_tile_num_frames ] = 1.0
1086- conditioning_mask = conditioning_mask .gather (1 , video_ids [:, 0 ])
1087-
1088- video_ids = self ._scale_video_ids (
1089- video_ids ,
1090- scale_factor = self .vae_spatial_compression_ratio ,
1091- scale_factor_t = self .vae_temporal_compression_ratio ,
1092- frame_index = 0 ,
1093- device = device
1094- )
1095- video_ids = video_ids .float ()
1096- video_ids [:, 0 ] = video_ids [:, 0 ] * (1.0 / frame_rate )
1097- if self .do_classifier_free_guidance :
1098- video_ids = torch .cat ([video_ids , video_ids ], dim = 0 )
1099-
11001090 with self .progress_bar (total = inner_num_inference_steps ) as progress_bar :
11011091 for i , t in enumerate (inner_timesteps ):
11021092 if self .interrupt :
0 commit comments