@@ -430,6 +430,7 @@ def check_inputs(
430430 video ,
431431 frame_index ,
432432 strength ,
433+ denoise_strength ,
433434 height ,
434435 width ,
435436 callback_on_step_end_tensor_inputs = None ,
@@ -497,6 +498,9 @@ def check_inputs(
497498 elif isinstance (video , list ) and isinstance (strength , list ) and len (video ) != len (strength ):
498499 raise ValueError ("If `conditions` is not provided, `video` and `strength` must be of the same length." )
499500
501+ if denoise_strength < 0 or denoise_strength > 1 :
502+ raise ValueError (f"The value of strength should in [0.0, 1.0] but is { denoise_strength } " )
503+
500504 @staticmethod
501505 def _prepare_video_ids (
502506 batch_size : int ,
@@ -649,6 +653,8 @@ def prepare_latents(
649653 width : int = 704 ,
650654 num_frames : int = 161 ,
651655 num_prefix_latent_frames : int = 2 ,
656+ sigma : Optional [torch .Tensor ] = None ,
657+ latents : Optional [torch .Tensor ] = None ,
652658 generator : Optional [torch .Generator ] = None ,
653659 device : Optional [torch .device ] = None ,
654660 dtype : Optional [torch .dtype ] = None ,
@@ -658,7 +664,18 @@ def prepare_latents(
658664 latent_width = width // self .vae_spatial_compression_ratio
659665
660666 shape = (batch_size , num_channels_latents , num_latent_frames , latent_height , latent_width )
661- latents = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
667+
668+ noise = randn_tensor (shape , generator = generator , device = device , dtype = dtype )
669+ if latents is not None and sigma is not None :
670+ if latents .shape != shape :
671+ raise ValueError (
672+ f"Latents shape { latents .shape } does not match expected shape { shape } . Please check the input."
673+ )
674+ latents = latents .to (device = device , dtype = dtype )
675+ sigma = sigma .to (device = device , dtype = dtype )
676+ latents = sigma * noise + (1 - sigma ) * latents
677+ else :
678+ latents = noise
662679
663680 if len (conditions ) > 0 :
664681 condition_latent_frames_mask = torch .zeros (
@@ -766,6 +783,13 @@ def prepare_latents(
766783
767784 return latents , conditioning_mask , video_ids , extra_conditioning_num_latents
768785
786+ def get_timesteps (self , sigmas , timesteps , num_inference_steps , strength ):
787+ num_steps = min (int (num_inference_steps * strength ), num_inference_steps )
788+ start_index = max (num_inference_steps - num_steps , 0 )
789+ sigmas = sigmas [start_index :]
790+ timesteps = timesteps [start_index :]
791+ return sigmas , timesteps , num_inference_steps - start_index
792+
769793 @property
770794 def guidance_scale (self ):
771795 return self ._guidance_scale
@@ -799,6 +823,7 @@ def __call__(
799823 video : List [PipelineImageInput ] = None ,
800824 frame_index : Union [int , List [int ]] = 0 ,
801825 strength : Union [float , List [float ]] = 1.0 ,
826+ denoise_strength : float = 1.0 ,
802827 prompt : Union [str , List [str ]] = None ,
803828 negative_prompt : Optional [Union [str , List [str ]]] = None ,
804829 height : int = 512 ,
@@ -842,6 +867,10 @@ def __call__(
842867 generation. If not provided, one has to pass `conditions`.
843868 strength (`float` or `List[float]`, *optional*):
844869 The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
870+ denoise_strength (`float`, defaults to `1.0`):
871+ The strength of the noise added to the latents for editing. Higher strength leads to more noise added
872+ to the latents, therefore leading to more differences between original video and generated video. This
873+ is useful for video-to-video editing.
845874 prompt (`str` or `List[str]`, *optional*):
846875 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
847876 instead.
@@ -918,8 +947,6 @@ def __call__(
918947
919948 if isinstance (callback_on_step_end , (PipelineCallback , MultiPipelineCallbacks )):
920949 callback_on_step_end_tensor_inputs = callback_on_step_end .tensor_inputs
921- if latents is not None :
922- raise ValueError ("Passing latents is not yet supported." )
923950
924951 # 1. Check inputs. Raise error if not correct
925952 self .check_inputs (
@@ -929,6 +956,7 @@ def __call__(
929956 video = video ,
930957 frame_index = frame_index ,
931958 strength = strength ,
959+ denoise_strength = denoise_strength ,
932960 height = height ,
933961 width = width ,
934962 callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
@@ -977,8 +1005,9 @@ def __call__(
9771005 strength = [strength ] * num_conditions
9781006
9791007 device = self ._execution_device
1008+ vae_dtype = self .vae .dtype
9801009
981- # 3. Prepare text embeddings
1010+ # 3. Prepare text embeddings & conditioning image/video
9821011 (
9831012 prompt_embeds ,
9841013 prompt_attention_mask ,
@@ -1000,8 +1029,6 @@ def __call__(
10001029 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
10011030 prompt_attention_mask = torch .cat ([negative_prompt_attention_mask , prompt_attention_mask ], dim = 0 )
10021031
1003- vae_dtype = self .vae .dtype
1004-
10051032 conditioning_tensors = []
10061033 is_conditioning_image_or_video = image is not None or video is not None
10071034 if is_conditioning_image_or_video :
@@ -1032,7 +1059,24 @@ def __call__(
10321059 )
10331060 conditioning_tensors .append (condition_tensor )
10341061
1035- # 4. Prepare latent variables
1062+ # 4. Prepare timesteps
1063+ latent_num_frames = (num_frames - 1 ) // self .vae_temporal_compression_ratio + 1
1064+ latent_height = height // self .vae_spatial_compression_ratio
1065+ latent_width = width // self .vae_spatial_compression_ratio
1066+ sigmas = linear_quadratic_schedule (num_inference_steps )
1067+ timesteps = sigmas * 1000
1068+ timesteps , num_inference_steps = retrieve_timesteps (self .scheduler , num_inference_steps , device , timesteps )
1069+ num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
1070+ self ._num_timesteps = len (timesteps )
1071+
1072+ latent_sigma = None
1073+ if denoise_strength < 1 :
1074+ sigmas , timesteps , num_inference_steps = self .get_timesteps (
1075+ sigmas , timesteps , num_inference_steps , denoise_strength
1076+ )
1077+ latent_sigma = sigmas [:1 ].repeat (batch_size * num_videos_per_prompt )
1078+
1079+ # 5. Prepare latent variables
10361080 num_channels_latents = self .transformer .config .in_channels
10371081 latents , conditioning_mask , video_coords , extra_conditioning_num_latents = self .prepare_latents (
10381082 conditioning_tensors ,
@@ -1043,6 +1087,8 @@ def __call__(
10431087 height = height ,
10441088 width = width ,
10451089 num_frames = num_frames ,
1090+ sigma = latent_sigma ,
1091+ latents = latents ,
10461092 generator = generator ,
10471093 device = device ,
10481094 dtype = torch .float32 ,
@@ -1056,21 +1102,6 @@ def __call__(
10561102 if self .do_classifier_free_guidance :
10571103 video_coords = torch .cat ([video_coords , video_coords ], dim = 0 )
10581104
1059- # 5. Prepare timesteps
1060- latent_num_frames = (num_frames - 1 ) // self .vae_temporal_compression_ratio + 1
1061- latent_height = height // self .vae_spatial_compression_ratio
1062- latent_width = width // self .vae_spatial_compression_ratio
1063- sigmas = linear_quadratic_schedule (num_inference_steps )
1064- timesteps = sigmas * 1000
1065- timesteps , num_inference_steps = retrieve_timesteps (
1066- self .scheduler ,
1067- num_inference_steps ,
1068- device ,
1069- timesteps = timesteps ,
1070- )
1071- num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
1072- self ._num_timesteps = len (timesteps )
1073-
10741105 # 6. Denoising loop
10751106 with self .progress_bar (total = num_inference_steps ) as progress_bar :
10761107 for i , t in enumerate (timesteps ):
0 commit comments