@@ -521,7 +521,7 @@ def __call__(
521521 num_frames_per_chunk : int = 93 ,
522522 num_inference_steps : int = 36 ,
523523 guidance_scale : float = 3.0 ,
524- num_videos_per_prompt : Optional [ int ] = 1 ,
524+ num_videos_per_prompt : int = 1 ,
525525 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
526526 latents : Optional [torch .Tensor ] = None ,
527527 prompt_embeds : Optional [torch .Tensor ] = None ,
@@ -721,15 +721,23 @@ def __call__(
721721 vae_dtype = self .vae .dtype
722722 transformer_dtype = self .transformer .dtype
723723
724- img_context = torch .zeros (
725- batch_size ,
726- self .transformer .config .img_context_num_tokens ,
727- self .transformer .config .img_context_dim_in ,
728- device = prompt_embeds .device ,
729- dtype = transformer_dtype ,
730- )
731- encoder_hidden_states = (prompt_embeds , img_context )
732- neg_encoder_hidden_states = (negative_prompt_embeds , img_context )
724+ if getattr (self .transformer .config , "img_context_dim_in" , None ):
725+ img_context = torch .zeros (
726+ batch_size ,
727+ self .transformer .config .img_context_num_tokens ,
728+ self .transformer .config .img_context_dim_in ,
729+ device = prompt_embeds .device ,
730+ dtype = transformer_dtype ,
731+ )
732+
733+ if num_videos_per_prompt > 1 :
734+ img_context = img_context .repeat_interleave (num_videos_per_prompt , dim = 0 )
735+
736+ encoder_hidden_states = (prompt_embeds , img_context )
737+ neg_encoder_hidden_states = (negative_prompt_embeds , img_context )
738+ else :
739+ encoder_hidden_states = prompt_embeds
740+ neg_encoder_hidden_states = negative_prompt_embeds
733741
734742 if controls is not None and self .controlnet is None :
735743 logger .warning ("`controls` was provided but `controlnet` is None; ignoring `controls`." )
@@ -798,7 +806,7 @@ def __call__(
798806 chunk_stride = num_frames_per_chunk - num_conditional_frames
799807 chunk_idxs = [
800808 (start_idx , min (start_idx + num_frames_per_chunk , num_frames_out ))
801- for start_idx in range (0 , num_frames_out , chunk_stride )
809+ for start_idx in range (0 , num_frames_out - num_conditional_frames , chunk_stride )
802810 ]
803811
804812 video_chunks = []
@@ -810,6 +818,7 @@ def decode_latents(latents):
810818 video = self .vae .decode (latents .to (dtype = self .vae .dtype , device = device ), return_dict = False )[0 ]
811819 return video
812820
821+ latents_arg = latents
813822 initial_num_cond_latent_frames = 0 if video is None or controls is not None else num_cond_latent_frames
814823 latent_chunks = []
815824 num_chunks = len (chunk_idxs )
@@ -844,7 +853,7 @@ def decode_latents(latents):
844853 num_cond_latent_frames = initial_num_cond_latent_frames
845854 if chunk_idx == 0
846855 else num_cond_latent_frames ,
847- # latents=latents ,
856+ latents = latents_arg ,
848857 )
849858 cond_mask = cond_mask .to (transformer_dtype )
850859 cond_timestep = torch .ones_like (cond_indicator ) * conditional_frame_timestep
@@ -866,7 +875,6 @@ def decode_latents(latents):
866875 latents_std = self .latents_std .to (device = device , dtype = transformer_dtype )
867876 controls_latents = (controls_latents - latents_mean ) / latents_std
868877
869- # breakpoint()
870878 # Denoising loop
871879 self .scheduler .set_timesteps (num_inference_steps , device = device )
872880 timesteps = self .scheduler .timesteps
@@ -980,7 +988,7 @@ def decode_latents(latents):
980988 video = (video * 255 ).astype (np .uint8 )
981989 video_batch = []
982990 for vid in video :
983- # vid = self.safety_checker.check_video_safety(vid)
991+ vid = self .safety_checker .check_video_safety (vid )
984992 if vid is None :
985993 video_batch .append (np .zeros_like (video [0 ]))
986994 else :
0 commit comments