@@ -973,6 +973,30 @@ def __call__(
 ):
     latent_chunk = self._select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
     latent_tile_num_frames = latent_chunk.shape[2]
+
+    if start_index > 0:
+        last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
+        last_latent_tile_num_frames = last_latent_chunk.shape[2]
+        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
+        total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+    else:
+        total_latent_num_frames = latent_tile_num_frames
+
+    latent_chunk = self._pack_latents(
+        latent_chunk,
+        self.transformer_spatial_patch_size,
+        self.transformer_temporal_patch_size,
+    )
+
+    video_ids = self._prepare_video_ids(
+        batch_size,
+        total_latent_num_frames,
+        latent_tile_height,
+        latent_tile_width,
+        patch_size_t=self.transformer_temporal_patch_size,
+        patch_size=self.transformer_spatial_patch_size,
+        device=device,
+    )

     # Set timesteps
     inner_timesteps, inner_num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -981,17 +1005,6 @@ def __call__(
     self._num_timesteps = len(inner_timesteps)

     if start_index == 0:
-        latent_chunk = self._pack_latents(latent_chunk, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
-
-        video_ids = self._prepare_video_ids(
-            batch_size,
-            latent_tile_num_frames,
-            latent_tile_height,
-            latent_tile_width,
-            patch_size_t=self.transformer_temporal_patch_size,
-            patch_size=self.transformer_spatial_patch_size,
-            device=device,
-        )
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -1066,26 +1079,6 @@ def __call__(
         )
         first_tile_out_latents = tile_out_latents.clone()
     else:
-        last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
-        last_latent_tile_num_frames = last_latent_chunk.shape[2]
-        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
-        total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
-        latent_chunk = self._pack_latents(
-            latent_chunk,
-            self.transformer_spatial_patch_size,
-            self.transformer_temporal_patch_size,
-        )
-
-        video_ids = self._prepare_video_ids(
-            batch_size,
-            total_latent_num_frames,
-            latent_tile_height,
-            latent_tile_width,
-            patch_size_t=self.transformer_temporal_patch_size,
-            patch_size=self.transformer_spatial_patch_size,
-            device=device,
-        )
-
         conditioning_mask = torch.zeros(
             (batch_size, total_latent_num_frames), dtype=torch.float32, device=device,
         )
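The net effect of this diff is to hoist the per-chunk preparation (optionally prepending the temporal-overlap latents from the previous chunk's output, then packing and preparing video ids) out of the `start_index == 0` / `else` branches so it runs once before the timestep setup. Below is a minimal standalone sketch of that control flow; `select_latents` and `prepare_chunk` are hypothetical, simplified stand-ins for the pipeline's `_select_latents` / `_pack_latents` / `_prepare_video_ids` helpers, not the actual implementation.

```python
# Minimal sketch (assumed helper names, not the pipeline's real code): prepend the
# temporal-overlap frames from the previous chunk, then process the combined chunk
# once, regardless of whether this is the first chunk.
from typing import Optional, Tuple

import torch


def select_latents(latents: torch.Tensor, start: int, end: int) -> torch.Tensor:
    # Simplified stand-in for _select_latents: slice frames along dim=2, inclusive end,
    # with negative indices counted from the end.
    num_frames = latents.shape[2]
    start = start if start >= 0 else num_frames + start
    end = end if end >= 0 else num_frames + end
    return latents[:, :, start : end + 1]


def prepare_chunk(
    tile_latents: torch.Tensor,                 # current tile latents, shape (B, C, F, H, W)
    tile_out_latents: Optional[torch.Tensor],   # denoised latents from the previous chunk
    start_index: int,
    end_index: int,
    temporal_overlap: int,
) -> Tuple[torch.Tensor, int]:
    latent_chunk = select_latents(tile_latents, start_index, min(end_index - 1, tile_latents.shape[2] - 1))
    latent_tile_num_frames = latent_chunk.shape[2]

    if start_index > 0:
        # Condition later chunks on the last `temporal_overlap` frames of the
        # previous chunk's output by prepending them along the frame dimension.
        last_latent_chunk = select_latents(tile_out_latents, -temporal_overlap, -1)
        latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
        total_latent_num_frames = last_latent_chunk.shape[2] + latent_tile_num_frames
    else:
        total_latent_num_frames = latent_tile_num_frames

    # Packing and video-id preparation would follow here, once, for both branches.
    return latent_chunk, total_latent_num_frames


# Usage example with toy shapes:
tile_latents = torch.randn(1, 128, 12, 16, 16)
prev_out = torch.randn(1, 128, 6, 16, 16)
chunk, n = prepare_chunk(tile_latents, prev_out, start_index=4, end_index=10, temporal_overlap=2)
print(chunk.shape, n)  # 2 overlap frames + 6 selected frames -> 8 total
```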