diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py
index 7b548e003c63..f4c855b4fd04 100644
--- a/src/diffusers/modular_pipelines/wan/__init__.py
+++ b/src/diffusers/modular_pipelines/wan/__init__.py
@@ -25,12 +25,14 @@
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
         "AUTO_BLOCKS",
+        "IMAGE2VIDEO_BLOCKS",
         "TEXT2VIDEO_BLOCKS",
         "WanAutoBeforeDenoiseStep",
         "WanAutoBlocks",
         "WanAutoBlocks",
         "WanAutoDecodeStep",
         "WanAutoDenoiseStep",
+        "WanAutoVaeEncoderStep",
     ]
     _import_structure["modular_pipeline"] = ["WanModularPipeline"]
 
@@ -45,11 +47,13 @@
         from .modular_blocks import (
             ALL_BLOCKS,
             AUTO_BLOCKS,
+            IMAGE2VIDEO_BLOCKS,
             TEXT2VIDEO_BLOCKS,
             WanAutoBeforeDenoiseStep,
             WanAutoBlocks,
             WanAutoDecodeStep,
             WanAutoDenoiseStep,
+            WanAutoVaeEncoderStep,
         )
         from .modular_pipeline import WanModularPipeline
 else:
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
index ef65b6453725..d164f9c45643 100644
--- a/src/diffusers/modular_pipelines/wan/before_denoise.py
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -282,7 +282,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
-            )
+            ),
+            OutputParam("height", type_hint=int),
+            OutputParam("width", type_hint=int),
+            OutputParam("num_frames", type_hint=int),
         ]
 
     @staticmethod
diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
index 76c5cda5f95e..afddb9a1ac0e 100644
--- a/src/diffusers/modular_pipelines/wan/denoise.py
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -34,6 +34,56 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class WanI2VLoopBeforeDenoiser(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", UniPCMultistepScheduler),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that prepares the latent input for the denoiser. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `WanI2VDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process.",
+            ),
+            InputParam(
+                "latent_condition",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latent condition to use for the denoising process.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latent_model_inputs",
+                type_hint=torch.Tensor,
+                description="The concatenated noisy and conditioning latents to use for the denoising process.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
+        block_state.latent_model_inputs = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
+        return components, block_state
+
+
 class WanLoopDenoiser(PipelineBlock):
     model_name = "wan"
 
@@ -102,7 +152,7 @@ def __call__(
         components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
 
         # Prepare mini‐batches according to guidance method and `guider_input_fields`
-        # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
+        # Each guider_state_batch will have .prompt_embeds.
         # e.g. for CFG, we prepare two batches: one for uncond, one for cond
         # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
         # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ def __call__(
             guider_state_batch.noise_pred = components.transformer(
                 hidden_states=block_state.latents.to(transformer_dtype),
                 timestep=t.flatten(),
-                encoder_hidden_states=prompt_embeds,
+                encoder_hidden_states=prompt_embeds.to(transformer_dtype),
+                attention_kwargs=block_state.attention_kwargs,
+                return_dict=False,
+            )[0]
+            components.guider.cleanup_models(components.transformer)
+
+        # Perform guidance
+        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+        return components, block_state
+
+
+class WanI2VLoopDenoiser(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 5.0}),
+                default_creation_method="from_config",
+            ),
+            ComponentSpec("transformer", WanTransformer3DModel),
+        ]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoise the latents with guidance. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `WanDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("attention_kwargs"),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [
+            InputParam(
+                "latent_model_inputs",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process.",
+            ),
+            InputParam(
+                "image_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The encoder hidden states for the image inputs.",
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process.",
+            ),
+            InputParam(
+                kwargs_type="guider_input_fields",
+                description=(
+                    "All conditional model inputs that need to be prepared with guider. "
+                    "It should contain prompt_embeds/negative_prompt_embeds. "
+                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+                ),
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+        # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+        guider_input_fields = {
+            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+        }
+        transformer_dtype = components.transformer.dtype
+
+        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+        # Prepare mini‐batches according to guidance method and `guider_input_fields`
+        # Each guider_state_batch will have .prompt_embeds.
+        # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+        # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
+        # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
+        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+        # run the denoiser for each guidance batch
+        for guider_state_batch in guider_state:
+            components.guider.prepare_models(components.transformer)
+            cond_kwargs = guider_state_batch.as_dict()
+            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+            prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+            # Predict the noise residual
+            # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+            guider_state_batch.noise_pred = components.transformer(
+                hidden_states=block_state.latent_model_inputs.to(transformer_dtype),
+                timestep=t.flatten(),
+                encoder_hidden_states=prompt_embeds.to(transformer_dtype),
+                encoder_hidden_states_image=block_state.image_embeds.to(transformer_dtype),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
             )[0]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
         WanLoopDenoiser,
         WanLoopAfterDenoiser,
     ]
-    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+    block_names = ["denoiser", "after_denoiser"]
 
     @property
     def description(self) -> str:
@@ -257,5 +412,26 @@ def description(self) -> str:
             "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
             " - `WanLoopDenoiser`\n"
             " - `WanLoopAfterDenoiser`\n"
-            "This block supports both text2vid tasks."
+            "This block supports the text2vid task."
+        )
+
+
+class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
+    block_classes = [
+        WanI2VLoopBeforeDenoiser,
+        WanI2VLoopDenoiser,
+        WanLoopAfterDenoiser,
+    ]
+    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
+            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+            " - `WanI2VLoopBeforeDenoiser`\n"
+            " - `WanI2VLoopDenoiser`\n"
+            " - `WanI2VLoopAfterDenoiser`\n"
+            "This block supports the image-to-video and first-last-frame-to-video tasks."
         )
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index b2ecfd1aa61a..425fb7c748bd 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -17,11 +17,14 @@
 
 import regex as re
 import torch
-from transformers import AutoTokenizer, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
 from ...utils import is_ftfy_available, logging
+from ...video_processor import VideoProcessor
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class WanTextEncoderStep(PipelineBlock):
     model_name = "wan"
 
@@ -240,3 +257,238 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         # Add outputs
         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanImageEncoderStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "Image Encoder step to compute image embeddings to guide the video generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_encoder", CLIPVisionModel),
+            ComponentSpec("image_processor", CLIPImageProcessor),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return []
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "image",
+                required=True,
+                description="The input image to condition the generation on for first-frame conditioned video generation.",
+            ),
+            InputParam(
+                "last_image",
+                required=False,
+                description="The last image to condition the generation on for last-frame conditioned video generation.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "image_embeds",
+                type_hint=torch.Tensor,
+                description="image embeddings used to guide the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        if not isinstance(block_state.image, PipelineImageInput):
+            raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
+        if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
+            raise ValueError(
+                f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
+            )
+
+    @staticmethod
+    def encode_image(
+        components,
+        image: Union[PipelineImageInput, List[PipelineImageInput]],
+        device: torch.device,
+    ):
+        image = components.image_processor(images=image, return_tensors="pt").to(device)
+        image_embeds = components.image_encoder(**image, output_hidden_states=True)
+        return image_embeds.hidden_states[-2]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+        block_state.device = components._execution_device
+
+        # Encode input images
+        image = block_state.image
+        if block_state.last_image is not None:
+            image = [block_state.image, block_state.last_image]
+
+        block_state.image_embeds = self.encode_image(components, image, block_state.device)
+
+        # Add outputs
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class WanVaeEncoderStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return (
+            "VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLWan),
+            ComponentSpec(
+                "video_processor",
+                VideoProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("last_image", required=False),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("num_frames", type_hint=int),
+            InputParam("batch_size", type_hint=int),
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latent_condition",
+                type_hint=torch.Tensor,
+                description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
+            ),
+            OutputParam("num_channels_latents", type_hint=int),
+        ]
+
+    @staticmethod
+    def _encode_vae_image(
+        components: WanModularPipeline,
+        batch_size: int,
+        height: int,
+        width: int,
+        num_frames: int,
+        image: torch.Tensor,
+        device: torch.device,
+        dtype: torch.dtype,
+        last_image: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ):
+        latent_height = height // components.vae_scale_factor_spatial
+        latent_width = width // components.vae_scale_factor_spatial
+
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean)
+            .view(1, components.vae.config.z_dim, 1, 1, 1)
+            .to(device, dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(device, dtype)
+
+        image = image.unsqueeze(2)
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
+        video_condition = video_condition.to(device=device, dtype=dtype)
+
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+        latent_condition = latent_condition.to(dtype)
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(
+            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
+        )
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(
+            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
+        )
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(latent_condition.device)
+        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
+
+        return latent_condition
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+        block_state.num_channels_latents = components.vae.config.z_dim
+
+        block_state.image = components.video_processor.preprocess(
+            block_state.image, height=block_state.height, width=block_state.width
+        ).to(block_state.device, dtype=torch.float32)
+
+        if block_state.last_image is not None:
+            block_state.last_image = components.video_processor.preprocess(
+                block_state.last_image, height=block_state.height, width=block_state.width
+            ).to(block_state.device, dtype=torch.float32)
+
+        block_state.latent_condition = self._encode_vae_image(
+            components,
+            block_state.batch_size,
+            block_state.height,
+            block_state.width,
+            block_state.num_frames,
+            block_state.image,
+            block_state.device,
+            block_state.dtype,
+            block_state.last_image,
+            block_state.generator,
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
index 5f4c1a983566..dea61c227cf9 100644
--- a/src/diffusers/modular_pipelines/wan/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -21,13 +21,43 @@
     WanSetTimestepsStep,
 )
 from .decoders import WanDecodeStep
-from .denoise import WanDenoiseStep
-from .encoders import WanTextEncoderStep
+from .denoise import WanDenoiseStep, WanI2VDenoiseStep
+from .encoders import WanImageEncoderStep, WanTextEncoderStep, WanVaeEncoderStep
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class WanAutoImageEncoderStep(AutoPipelineBlocks):
+    block_classes = [WanImageEncoderStep]
+    block_names = ["image_encoder"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "Image encoder step that encodes the image inputs into a conditioning embedding.\n"
+            + "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+            + " - `WanImageEncoderStep` (image_encoder) is used when `image`, and possibly `last_image` is provided."
+            + " - if `image` is not provided, this step will be skipped."
+        )
+
+
+class WanAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [WanVaeEncoderStep]
+    block_names = ["img2vid"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encode the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+            + " - `WanVaeEncoderStep` (img2vid) is used when `image`, and possibly `last_image` is provided."
+            + " - if `image` is not provided, this step will be skipped."
+        )
+
+
 # before_denoise: text2vid
 class WanBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
@@ -48,44 +78,72 @@ def description(self):
         )
 
 
-# before_denoise: all task (text2vid,)
+# before_denoise: img2vid
+class WanI2VBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [
+        WanInputStep,
+        WanSetTimestepsStep,
+        WanPrepareLatentsStep,
+        WanImageEncoderStep,
+        WanVaeEncoderStep,
+    ]
+    block_names = ["input", "set_timesteps", "prepare_latents", "image_encoder", "vae_encoder"]
+
+    @property
+    def description(self):
+        return (
+            "Before denoise step that prepare the inputs for the denoise step for image-to-video and first-last-frame-to-video tasks.\n"
+            + "This is a sequential pipeline blocks:\n"
+            + " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+            + " - `WanImageEncoderStep` is used to encode the image inputs into a conditioning embedding\n"
+            + " - `WanVaeEncoderStep` is used to encode the image/last-image inputs into their latent representations\n"
+        )
+
+
+# before_denoise: all task (text2vid, img2vid)
 class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
     block_classes = [
         WanBeforeDenoiseStep,
+        WanI2VBeforeDenoiseStep,
     ]
-    block_names = ["text2vid"]
-    block_trigger_inputs = [None]
+    block_names = ["text2vid", "img2vid"]
+    block_trigger_inputs = [None, "image"]
 
     @property
     def description(self):
         return (
             "Before denoise step that prepare the inputs for the denoise step.\n"
-            + "This is an auto pipeline block that works for text2vid.\n"
+            + "This is an auto pipeline block that works for text2vid, img2vid, first-last-frame2vid.\n"
            + " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+            + " - `WanI2VBeforeDenoiseStep` (img2vid) is used when `image` is provided.\n"
         )
 
 
-# denoise: text2vid
+# denoise: text2vid, img2vid
 class WanAutoDenoiseStep(AutoPipelineBlocks):
     block_classes = [
         WanDenoiseStep,
+        WanI2VDenoiseStep,
     ]
-    block_names = ["denoise"]
-    block_trigger_inputs = [None]
+    block_names = ["denoise", "denoise_i2v"]
+    block_trigger_inputs = [None, "image"]
 
     @property
     def description(self) -> str:
         return (
             "Denoise step that iteratively denoise the latents. "
-            "This is a auto pipeline block that works for text2vid tasks.."
-            " - `WanDenoiseStep` (denoise) for text2vid tasks."
+            "This is a auto pipeline block that works for text2vid and img2vid tasks..."
+            " - `WanDenoiseStep` (denoise) for text2vid task."
+            " - `WanI2VDenoiseStep` (denoise_i2v) for img2vid task, which is used when `image` is provided.\n"
         )
 
 
 # decode: all task (text2img, img2img, inpainting)
 class WanAutoDecodeStep(AutoPipelineBlocks):
     block_classes = [WanDecodeStep]
-    block_names = ["non-inpaint"]
+    block_names = ["decode"]
     block_trigger_inputs = [None]
 
     @property
@@ -116,6 +174,33 @@ def description(self):
         )
 
 
+# img2vid and first-last-frame2vid
+class WanI2VAutoBlocks(SequentialPipelineBlocks):
+    block_classes = [
+        WanTextEncoderStep,
+        WanAutoBeforeDenoiseStep,
+        WanImageEncoderStep,
+        WanAutoVaeEncoderStep,
+        WanAutoDenoiseStep,
+        WanAutoDecodeStep,
+    ]
+    block_names = [
+        "text_encoder",
+        "before_denoise",
+        "image_encoder",
+        "vae_encoder",
+        "denoise",
+        "decoder",
+    ]
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for text-to-video using Wan.\n"
+            + "- for image-to-video and first-last-frame-to-video generation, you need to provide is `image`, and possibly `last_image`"
+        )
+
+
 TEXT2VIDEO_BLOCKS = InsertableDict(
     [
         ("text_encoder", WanTextEncoderStep),
@@ -128,9 +213,25 @@ def description(self):
 )
 
 
+IMAGE2VIDEO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", WanTextEncoderStep),
+        ("input", WanInputStep),
+        ("set_timesteps", WanSetTimestepsStep),
+        ("prepare_latents", WanPrepareLatentsStep),
+        ("image_encoder", WanImageEncoderStep),
+        ("vae_encoder", WanVaeEncoderStep),
+        ("denoise", WanI2VDenoiseStep),
+        ("decode", WanDecodeStep),
+    ]
+)
+
+
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", WanTextEncoderStep),
+        ("image_encoder", WanAutoImageEncoderStep),
+        ("vae_encoder", WanAutoVaeEncoderStep),
         ("before_denoise", WanAutoBeforeDenoiseStep),
         ("denoise", WanAutoDenoiseStep),
         ("decode", WanAutoDecodeStep),
@@ -140,5 +241,6 @@ def description(self):
 
 ALL_BLOCKS = {
     "text2video": TEXT2VIDEO_BLOCKS,
+    "image2video": IMAGE2VIDEO_BLOCKS,
     "auto": AUTO_BLOCKS,
 }