diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
index 12be5efeccb2..d59b4ce3cb17 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -370,7 +370,6 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

-    # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
     def prepare_latents(
         self,
         image: PipelineImageInput,
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index b075cf5ba014..24e9cccdb440 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -175,6 +175,7 @@ def __init__(
         image_encoder: CLIPVisionModel = None,
         transformer_2: WanTransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

@@ -188,10 +189,10 @@ def __init__(
             image_processor=image_processor,
             transformer_2=transformer_2,
         )
-        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)

-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor

@@ -419,8 +420,12 @@ def prepare_latents(
         else:
             latents = latents.to(device=device, dtype=dtype)

-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -453,6 +458,13 @@ def prepare_latents(
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

         if last_image is None:
@@ -662,7 +674,7 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-        if self.config.boundary_ratio is None:
+        if self.config.boundary_ratio is None and not self.config.expand_timesteps:
             if image_embeds is None:
                 if last_image is None:
                     image_embeds = self.encode_image(image, device)
@@ -682,7 +694,8 @@ def __call__(
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
            )
-        latents, condition = self.prepare_latents(
+
+        latents_outputs = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -695,6 +708,10 @@ def __call__(
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs

         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -721,8 +738,17 @@ def __call__(
                     current_model = self.transformer_2
                     current_guidance_scale = guidance_scale_2

-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])

                 noise_pred = current_model(
                     hidden_states=latent_model_input,
@@ -766,6 +792,9 @@ def __call__(

         self._current_timestep = None

+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
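For reference, a minimal usage sketch of the new `expand_timesteps` path. This is not part of the diff: the checkpoint id and the output filename are placeholders, and overriding the registered `expand_timesteps` config value via a `from_pretrained` keyword argument is an assumption (it can equally be set in the pipeline's saved config).

# Minimal sketch (assumed usage) of the expand_timesteps conditioning path added above.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = WanImageToVideoPipeline.from_pretrained(
    "path/to/wan-i2v-checkpoint",  # placeholder checkpoint id, not from the diff
    torch_dtype=torch.bfloat16,
    expand_timesteps=True,  # assumed override of the config value registered in __init__
)
pipe.to("cuda")

image = load_image("input_frame.png")  # placeholder conditioning image

# With expand_timesteps enabled, prepare_latents returns a first-frame mask and the
# denoising loop broadcasts a per-token timestep instead of a single scalar per sample.
video = pipe(
    image=image,
    prompt="a short description of the desired motion",
    num_frames=81,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "output.mp4", fps=16)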