[modular diffusers] Wan I2V/FLF2V #11997

Open · wants to merge 7 commits into base: main

Conversation

@a-r-r-o-w (Member) commented on Jul 27, 2025

repos:

code (i2v):

import numpy as np
import torch
from diffusers import (
    AdaptiveProjectedGuidance,
    AutoGuidance,
    ClassifierFreeGuidance,
    ClassifierFreeZeroStarGuidance,
    LayerSkipConfig,
    PerturbedAttentionGuidance,
    SkipLayerGuidance,
    SmoothedEnergyGuidance,
    SmoothedEnergyGuidanceConfig,
    TangentialClassifierFreeGuidance,
)
from diffusers.modular_pipelines import SequentialPipelineBlocks, ComponentSpec, ComponentsManager
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS
from diffusers.utils.logging import set_verbosity_debug
from diffusers.utils import export_to_video, load_image

set_verbosity_debug()

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"

blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2VIDEO_BLOCKS)

pipeline = blocks.init_pipeline()
pipeline.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipeline.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
pipeline.load_components(["scheduler"], repo=model_id, subfolder="scheduler")
pipeline.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline.load_components(["image_encoder"], repo=model_id, subfolder="image_encoder", torch_dtype=torch.float32)
pipeline.load_components(["image_processor"], repo=model_id, subfolder="image_processor")
pipeline.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
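# Resize so the frame fits within a 480x832 area while keeping the aspect ratio,
# snapping height/width down to multiples of 16 as required by the model.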
max_area = 480 * 832
aspect_ratio = image.height / image.width
height = round(np.sqrt(max_area * aspect_ratio)) // 16 * 16
width = round(np.sqrt(max_area / aspect_ratio)) // 16 * 16
image = image.resize((width, height))

for guider_cls in [
    AutoGuidance,
    SkipLayerGuidance,
    ClassifierFreeGuidance,
    SmoothedEnergyGuidance,
    AdaptiveProjectedGuidance,
    PerturbedAttentionGuidance,
    ClassifierFreeZeroStarGuidance,
    TangentialClassifierFreeGuidance,
]:
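    # For each guidance technique: build its config kwargs, swap in a new guider via
    # update_components, then run the pipeline and export the result to a video file.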
    print(f"Testing {guider_cls.__name__}...")
    
    kwargs = {"guidance_scale": 5.0}
    if guider_cls is AutoGuidance:
        kwargs.update({"auto_guidance_config": LayerSkipConfig(indices=[13], skip_attention=True, skip_ff=True, dropout=0.1)})
        kwargs.update({"stop": 0.8})
    elif guider_cls is SkipLayerGuidance:
        kwargs.update({"skip_layer_config": LayerSkipConfig(indices=[21], skip_attention=True, skip_ff=True)})
        kwargs.update({"skip_layer_guidance_scale": 1.5})
        kwargs.update({"skip_layer_guidance_stop": 0.3})
    elif guider_cls is SmoothedEnergyGuidance:
        kwargs.update({"seg_guidance_config": SmoothedEnergyGuidanceConfig(indices=[21])})
        kwargs.update({"seg_guidance_scale": 2.0})
        kwargs.update({"seg_guidance_stop": 0.4})
    elif guider_cls is PerturbedAttentionGuidance:
        kwargs.update({"perturbed_guidance_config": LayerSkipConfig(indices=[11, 12, 13], skip_attention=False, skip_attention_scores=True, skip_ff=False)})
        kwargs.update({"perturbed_guidance_scale": 2.0})
        kwargs.update({"perturbed_guidance_stop": 0.25})
    elif guider_cls is AdaptiveProjectedGuidance:
        kwargs["adaptive_projected_guidance_rescale"] = 40.0

    pipeline.update_components(
        guider=ComponentSpec(
            name="cfg",
            type_hint=guider_cls,
            config=kwargs,
            default_creation_method="from_config",
        )
    )

    prompt = (
        "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
        "the background. Almost suddenly, waves of water emerge from the surface of the moon, taking the newly hatched "
        "astronaut for a swim! High quality, ultrarealistic detail and breath-taking movie-like camera shot."
    )
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    video = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        num_inference_steps=30,
        output="videos",
        generator=torch.Generator().manual_seed(0),
    )[0]
    output_filename = f"output_guider_{guider_cls.__name__.lower()}.mp4"
    export_to_video(video, output_filename, fps=16)

results (i2v):

Admittedly, not what I expected from the default prompt modification 🫠

CFG: output_guider_classifierfreeguidance.mp4
APG: output_guider_adaptiveprojectedguidance.mp4
TCFG: output_guider_tangentialclassifierfreeguidance.mp4
CFG-Zero*: output_guider_classifierfreezerostarguidance.mp4
PAG: output_guider_perturbedattentionguidance.mp4
AutoGuidance: output_guider_autoguidance.mp4

code (flf2v):

import numpy as np
import torch
import torchvision.transforms.functional as TF
from diffusers import (
    AdaptiveProjectedGuidance,
    AutoGuidance,
    ClassifierFreeGuidance,
    ClassifierFreeZeroStarGuidance,
    LayerSkipConfig,
    PerturbedAttentionGuidance,
    SkipLayerGuidance,
    SmoothedEnergyGuidance,
    SmoothedEnergyGuidanceConfig,
    TangentialClassifierFreeGuidance,
)
from diffusers.modular_pipelines import (
    ComponentSpec,
    ComponentsManager,
    InputParam,
    OutputParam,
    PipelineBlock,
    SequentialPipelineBlocks,
    WanModularPipeline,
)
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS
from diffusers.utils.logging import set_verbosity_debug
from diffusers.utils import export_to_video, load_image

set_verbosity_debug()


class PreprocessBlock(PipelineBlock):
    model_name = "wan"
    
    @property
    def description(self):
        return "Preprocess input image for the pipeline."

    @property
    def inputs(self):
        return [
            InputParam(name="image", required=True),
            InputParam(name="last_image", required=False),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(name="image"),
            OutputParam(name="last_image"),
        ]
    
    @staticmethod
    def aspect_ratio_resize(image, max_area=720 * 1280):
        aspect_ratio = image.height / image.width
        height = round(np.sqrt(max_area * aspect_ratio)) // 16 * 16
        width = round(np.sqrt(max_area / aspect_ratio)) // 16 * 16
        image = image.resize((width, height))
        return image, height, width

    @staticmethod
    def center_crop_resize(image, height, width):
        # Calculate resize ratio to match first frame dimensions
        resize_ratio = max(width / image.width, height / image.height)

        # Resize the image
        width = round(image.width * resize_ratio)
        height = round(image.height * resize_ratio)
        size = [width, height]
        image = TF.center_crop(image, size)

        return image, height, width

    def __call__(self, components, state):
        block_state = self.get_block_state(state)

        image = block_state.image
        block_state.image, height, width = self.aspect_ratio_resize(image)
        if block_state.last_image is not None and block_state.last_image.size != block_state.image.size:
            block_state.last_image, _, _ = self.center_crop_resize(block_state.last_image, height, width)

        self.set_block_state(state, block_state)
        return components, state


model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"

BLOCKS = IMAGE2VIDEO_BLOCKS.copy()
BLOCKS.insert("preprocess", PreprocessBlock(), 0)
blocks = SequentialPipelineBlocks.from_blocks_dict(BLOCKS)

pipeline = blocks.init_pipeline()
pipeline.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipeline.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
pipeline.load_components(["scheduler"], repo=model_id, subfolder="scheduler")
pipeline.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline.load_components(["image_encoder"], repo=model_id, subfolder="image_encoder", torch_dtype=torch.float32)
pipeline.load_components(["image_processor"], repo=model_id, subfolder="image_processor")
pipeline.to("cuda")

height = 512
width = 512
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")

for guider_cls in [
    AutoGuidance,
    SkipLayerGuidance,
    ClassifierFreeGuidance,
    SmoothedEnergyGuidance,
    AdaptiveProjectedGuidance,
    PerturbedAttentionGuidance,
    ClassifierFreeZeroStarGuidance,
    TangentialClassifierFreeGuidance,
]:
    print(f"Testing {guider_cls.__name__}...")
    
    kwargs = {"guidance_scale": 5.0}
    if guider_cls is AutoGuidance:
        kwargs.update({"auto_guidance_config": LayerSkipConfig(indices=[13], skip_attention=True, skip_ff=True, dropout=0.1)})
        kwargs.update({"stop": 0.8})
    elif guider_cls is SkipLayerGuidance:
        kwargs.update({"skip_layer_config": LayerSkipConfig(indices=[21], skip_attention=True, skip_ff=True)})
        kwargs.update({"skip_layer_guidance_scale": 1.5})
        kwargs.update({"skip_layer_guidance_stop": 0.3})
    elif guider_cls is SmoothedEnergyGuidance:
        kwargs.update({"seg_guidance_config": SmoothedEnergyGuidanceConfig(indices=[21])})
        kwargs.update({"seg_guidance_scale": 2.0})
        kwargs.update({"seg_guidance_stop": 0.4})
    elif guider_cls is PerturbedAttentionGuidance:
        kwargs.update({"perturbed_guidance_config": LayerSkipConfig(indices=[11, 12, 13], skip_attention=False, skip_attention_scores=True, skip_ff=False)})
        kwargs.update({"perturbed_guidance_scale": 2.0})
        kwargs.update({"perturbed_guidance_stop": 0.25})
    elif guider_cls is AdaptiveProjectedGuidance:
        kwargs["adaptive_projected_guidance_rescale"] = 40.0

    pipeline.update_components(
        guider=ComponentSpec(
            name="cfg",
            type_hint=guider_cls,
            config=kwargs,
            default_creation_method="from_config",
        )
    )

    prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
    negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    video = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        last_image=last_image,
        height=height,
        width=width,
        num_inference_steps=30,
        output="videos",
        generator=torch.Generator().manual_seed(0),
    )[0]
    output_filename = f"output_guider_{guider_cls.__name__.lower()}.mp4"
    export_to_video(video, output_filename, fps=16)

results (flf2v):

TODO: code is currently running

code for saving pipelines:

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks, ComponentSpec, ComponentsManager
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

# model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"

blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2VIDEO_BLOCKS)

pipeline = blocks.init_pipeline()
pipeline.load_components(["text_encoder"], repo=model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipeline.load_components(["tokenizer"], repo=model_id, subfolder="tokenizer")
pipeline.load_components(["scheduler"], repo=model_id, subfolder="scheduler")
pipeline.load_components(["transformer"], repo=model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipeline.load_components(["vae"], repo=model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline.load_components(["image_encoder"], repo=model_id, subfolder="image_encoder", torch_dtype=torch.float32)
pipeline.load_components(["image_processor"], repo=model_id, subfolder="image_processor")

# pipeline.push_to_hub("diffusers-internal-dev/Modular-Wan-I2V-14B-480P-Diffusers")
pipeline.push_to_hub("diffusers-internal-dev/Modular-Wan-I2V-14B-720P-Diffusers")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w marked this pull request as ready for review on July 28, 2025 at 21:49
@a-r-r-o-w requested a review from yiyixuxu on July 29, 2025 at 01:31

@yiyixuxu (Collaborator) left a comment:

Looking great! Thanks @a-r-r-o-w.
I left some comments, let me know what you think! From here I think it's very easy to support Wan 2.2!

return components, state


class WanVaeEncoderStep(PipelineBlock):

@yiyixuxu (Collaborator):

I think the VAE encoder step should just encode the image and return the image_latents.

The rest of the logic should go into prepare_latents.

This way it's more "modular", both for developing and for using.
E.g. at runtime, if you only want to change the first or last frame, you only need to encode one of them and use the image_latents directly; same if you want to change num_frames.
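
For example, roughly something like this (just a sketch; the class name, `components.video_processor`, and the exact VAE call are assumptions, not the final API):

from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam


class WanVaeImageEncoderStep(PipelineBlock):
    model_name = "wan"

    @property
    def inputs(self):
        return [InputParam(name="image", required=True)]

    @property
    def intermediate_outputs(self):
        return [OutputParam(name="image_latents")]

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        # Preprocess the image into a 5D (B, C, F, H, W) tensor and encode it with the VAE;
        # the video_processor attribute and `.latent_dist.mode()` call are assumptions here.
        pixel_values = components.video_processor.preprocess(block_state.image).unsqueeze(2)
        pixel_values = pixel_values.to(device=components.vae.device, dtype=components.vae.dtype)
        block_state.image_latents = components.vae.encode(pixel_values).latent_dist.mode()
        self.set_block_state(state, block_state)
        return components, state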

@a-r-r-o-w (Member, Author):

Ohh, I understand now. I missed the PrepareLatents blocks specific to the img2img example in SDXL. Will implement it correctly soon.


@a-r-r-o-w (Member, Author):

@yiyixuxu I was taking a look at this today. A little confused by what is entailed by "rest of the logic" here. Is it just the following part, or more?

        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
        if last_image is None:
            mask_lat_size[:, :, list(range(1, num_frames))] = 0
        else:
            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
        first_frame_mask = mask_lat_size[:, :, 0:1]
        first_frame_mask = torch.repeat_interleave(
            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
        )
        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
        mask_lat_size = mask_lat_size.view(
            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
        )
        mask_lat_size = mask_lat_size.transpose(1, 2)
        mask_lat_size = mask_lat_size.to(latent_condition.device)
        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)

@yiyixuxu (Collaborator) commented on Aug 6, 2025:

I see, it was indeed confusing; I don't think I thought it through before.

How about

  1. Have an encode_vae_video method that just encodes videos (after the image condition is pre-processed and converted into a 5D tensor), which can be imported from diffusers and used by different blocks, including custom ones (I started to refactor SDXL and made a similar one here for images).
  2. WanVaeEncoderStep should include the logic to create the video_condition and put it through encode_vae_video to encode into the latent_condition. The logic to create the mask should stay here too, since it is closely tied to how the video_condition is created. We should make a different Wan*VaeEncoderStep for the 5B IT2V.
  3. A separate prepare_latents should only include:
    1. generating the randn_tensor,
    2. adjusting the latent_condition and mask based on batch_size (or, if you want to keep prepare_latents for 1. only, we can handle this logic somewhere else).

basically

  • if the user wants to increase num_videos_per_prompt, they should not need to encode the images again
  • if you want to use different initial noise, you should not need to encode again

let me know what you think
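
For concreteness, a rough sketch of 1. and 3. (all names and signatures here are placeholders/assumptions, not the final API):

from diffusers.utils.torch_utils import randn_tensor


def encode_vae_video(vae, video, sample_mode="argmax"):
    # Encode an already pre-processed 5D (B, C, F, H, W) video tensor into latents.
    posterior = vae.encode(video.to(device=vae.device, dtype=vae.dtype)).latent_dist
    return posterior.mode() if sample_mode == "argmax" else posterior.sample()


def prepare_latents(latent_condition, mask, shape, dtype, device, generator=None):
    # 1. generate the initial noise; changing the seed does not require re-encoding the image(s)
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    # 2. expand the already-encoded condition and mask to the requested batch size
    batch_size = shape[0]
    latent_condition = latent_condition.repeat(batch_size // latent_condition.shape[0], 1, 1, 1, 1)
    mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1, 1)
    return latents, latent_condition, mask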

@yiyixuxu (Collaborator) commented:

Another comment w.r.t. auto_blocks: I think we only need to pack workflows that can use the same checkpoint into the same package. E.g. for 5B, the i2v and t2v can be packaged into one auto-blocks; but for 14B, t2v and i2v do not need to be combined, because you need to load a different checkpoint anyway and we should be able to map the checkpoint to the corresponding blocks.
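
Purely as an illustration of that mapping idea (the block names below are placeholders, not a proposed API):

# Only checkpoints that can share weights get grouped into one auto-blocks package;
# otherwise we map each checkpoint directly to its corresponding blocks.
CHECKPOINT_TO_BLOCKS = {
    "Wan-AI/Wan2.1-T2V-14B-Diffusers": "wan t2v blocks",       # 14B t2v weights
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": "wan i2v blocks",  # 14B i2v weights (different checkpoint)
    # a 5B checkpoint that runs both t2v and i2v would point at a single combined auto-blocks preset
}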
