[modular diffusers] Wan I2V/FLF2V #11997

Open · wants to merge 7 commits into base: main
Changes from 5 commits
4 changes: 4 additions & 0 deletions src/diffusers/modular_pipelines/wan/__init__.py
@@ -25,12 +25,14 @@
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"IMAGE2VIDEO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
"WanAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]

@@ -45,11 +47,13 @@
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE2VIDEO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
WanAutoVaeEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:
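
The new exports make the image-to-video block preset importable next to the existing text-to-video one. A minimal consumption sketch, assuming this branch is installed and that the presets are ordered mappings of block names to block classes (an assumption about their type, not something stated in this diff):

# Illustrative only; not part of this PR's diff.
from diffusers.modular_pipelines.wan import IMAGE2VIDEO_BLOCKS, TEXT2VIDEO_BLOCKS

print(list(TEXT2VIDEO_BLOCKS.keys()))   # existing text2vid preset
print(list(IMAGE2VIDEO_BLOCKS.keys()))  # new i2v/flf2v preset exported here
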
5 changes: 4 additions & 1 deletion src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -282,7 +282,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
),
OutputParam("height", type_hint=int),
OutputParam("width", type_hint=int),
OutputParam("num_frames", type_hint=int),
]

@staticmethod
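
Exposing `height`, `width`, and `num_frames` as intermediate outputs lets blocks later in the pipeline read the resolved video dimensions from the shared state instead of re-deriving them. A hedged sketch of a downstream consumer (the helper and the Wan scale factors, spatial 8x and temporal 4x, are assumptions for illustration, not code from this PR):

def latent_shape_from_state(block_state, vae_scale_factor=8, temporal_downsample=4):
    # Reads the intermediate outputs written by the before-denoise step above.
    latent_height = block_state.height // vae_scale_factor
    latent_width = block_state.width // vae_scale_factor
    latent_frames = (block_state.num_frames - 1) // temporal_downsample + 1
    return latent_frames, latent_height, latent_width
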
184 changes: 180 additions & 4 deletions src/diffusers/modular_pipelines/wan/denoise.py
@@ -34,6 +34,56 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class WanI2VLoopBeforeDenoiser(PipelineBlock):
model_name = "wan"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]

@property
def description(self) -> str:
return (
"Step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"latent_condition",
required=True,
type_hint=torch.Tensor,
description="The latent condition to use for the denoising process.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"concatenated_latents",
type_hint=torch.Tensor,
description="The concatenated noisy and conditioning latents to use for the denoising process.",
),
]

@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
return components, block_state
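
For intuition, the concatenation above happens along the channel dimension (dim=1), so the denoiser receives the noisy latents and the image/frame conditioning stacked as one input. A standalone shape sketch (the tensor sizes are illustrative assumptions, not values taken from this diff):

import torch

latents = torch.randn(1, 16, 21, 60, 104)          # (B, C, F, H', W') noisy video latents
latent_condition = torch.randn(1, 20, 21, 60, 104)  # mask plus encoded first/last-frame latents
concatenated_latents = torch.cat([latents, latent_condition], dim=1)
print(concatenated_latents.shape)  # torch.Size([1, 36, 21, 60, 104])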


class WanLoopDenoiser(PipelineBlock):
model_name = "wan"

@@ -102,7 +152,7 @@ def __call__(
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

# Prepare mini‐batches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ def __call__(
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

return components, block_state
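
As the in-loop comments above describe, the guider prepares one batch per guidance branch (conditional and unconditional for CFG) and merges their predictions afterwards. A toy illustration of that merge under the standard classifier-free guidance formula (the real combination happens inside `components.guider`; this helper is hypothetical):

import torch

def cfg_combine(noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float = 5.0) -> torch.Tensor:
    # guidance_scale=5.0 mirrors the ComponentSpec default used by these blocks.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)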


class WanI2VLoopDenoiser(PipelineBlock):
model_name = "wan"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]

@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"concatenated_latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"encoder_hidden_states_image",
required=True,
type_hint=torch.Tensor,
description="The encoder hidden states for the image inputs.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]

@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype

components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

# Prepare mini‐batches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)

# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")

# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.concatenated_latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
block_names = ["denoiser", "after_denoiser"]

@property
def description(self) -> str:
@@ -257,5 +412,26 @@ def description(self) -> str:
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
"This block supports the text2vid task."
)


class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanI2VLoopBeforeDenoiser,
WanI2VLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]

@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents with conditional first- and last-frame support.\n"
"Its loop logic is defined in the `WanDenoiseLoopWrapper.__call__` method.\n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanI2VLoopBeforeDenoiser`\n"
" - `WanI2VLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports the image-to-video and first-last-frame-to-video tasks."
)
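
To make the composition concrete: `WanI2VDenoiseStep` only declares which sub-blocks run inside each denoising iteration; the iteration itself lives in `WanDenoiseLoopWrapper.__call__`. A minimal, self-contained sketch of that loop-wrapper pattern (a toy stand-in, not the diffusers implementation):

class ToyDenoiseLoopWrapper:
    def __init__(self, sub_blocks):
        # e.g. [before_denoiser, denoiser, after_denoiser]
        self.sub_blocks = sub_blocks

    def __call__(self, components, block_state, timesteps):
        for i, t in enumerate(timesteps):
            # Run every sub-block once per timestep, threading the shared state through.
            for block in self.sub_blocks:
                components, block_state = block(components, block_state, i, t)
        return components, block_state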