@@ -571,6 +571,8 @@ def __call__(
571571 prompt_attention_mask : Optional [torch .Tensor ] = None ,
572572 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
573573 negative_prompt_attention_mask : Optional [torch .Tensor ] = None ,
574+ decode_timestep : Union [float , List [float ]] = 0.05 ,
575+ decode_noise_scale : Union [float , List [float ]] = 0.025 ,
574576 output_type : Optional [str ] = "pil" ,
575577 return_dict : bool = True ,
576578 attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -625,6 +627,10 @@ def __call__(
625627 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
626628 negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
627629 Pre-generated attention mask for negative text embeddings.
630+ decode_timestep (`float` or `List[float]`, *optional*, defaults to `0.05`):
631+ The timestep at which generated video is decoded.
632+ decode_noise_scale (`float` or `List[float]`, *optional*, defaults to `0.025`):
633+ The interpolation factor between random noise and denoised latents at the decode timestep.
628634 output_type (`str`, *optional*, defaults to `"pil"`):
629635 The output format of the generate image. Choose between
630636 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -849,6 +855,24 @@ def __call__(
849855 latents , self .vae .latents_mean , self .vae .latents_std , self .vae .config .scaling_factor
850856 )
851857 latents = latents .to (prompt_embeds .dtype )
858+
859+ if not self .vae .config .timestep_conditioning :
860+ timestep = None
861+ else :
862+ noise = torch .randn (latents .shape , generator = generator , device = device , dtype = latents .dtype )
863+ if not isinstance (decode_timestep , list ):
864+ decode_timestep = [decode_timestep ] * batch_size
865+ if decode_noise_scale is None :
866+ decode_noise_scale = decode_timestep
867+ elif not isinstance (decode_noise_scale , list ):
868+ decode_noise_scale = [decode_noise_scale ] * batch_size
869+
870+ timestep = torch .tensor (decode_timestep , device = device , dtype = latents .dtype )
871+ decode_noise_scale = torch .tensor (decode_noise_scale , device = device , dtype = latents .dtype )[
872+ :, None , None , None , None
873+ ]
874+ latents = (1 - decode_noise_scale ) * latents + decode_noise_scale * noise
875+
852876 video = self .vae .decode (latents , return_dict = False )[0 ]
853877 video = self .video_processor .postprocess_video (video , output_type = output_type )
854878
0 commit comments