Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/modular_pipelines/wan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"IMAGE2VIDEO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
"WanAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]

Expand All @@ -45,11 +47,13 @@
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE2VIDEO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
WanAutoVaeEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/modular_pipelines/wan/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
)
),
OutputParam("height", type_hint=int),
OutputParam("width", type_hint=int),
OutputParam("num_frames", type_hint=int),
]

@staticmethod
Expand Down
184 changes: 180 additions & 4 deletions src/diffusers/modular_pipelines/wan/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,56 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class WanI2VLoopBeforeDenoiser(PipelineBlock):
model_name = "stable-diffusion-xl"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", UniPCMultistepScheduler),
]

@property
def description(self) -> str:
return (
"Step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanI2VDenoiseLoopWrapper`)"
)

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"latent_condition",
required=True,
type_hint=torch.Tensor,
description="The latent condition to use for the denoising process.",
),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"concatenated_latents",
type_hint=torch.Tensor,
description="The concatenated noisy and conditioning latents to use for the denoising process.",
),
]

@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
return components, block_state


class WanLoopDenoiser(PipelineBlock):
model_name = "wan"

Expand Down Expand Up @@ -102,7 +152,7 @@ def __call__(
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

# Prepare mini‐batches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
Expand All @@ -120,7 +170,112 @@ def __call__(
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
components.guider.cleanup_models(components.transformer)

# Perform guidance
block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

return components, block_state


class WanI2VLoopDenoiser(PipelineBlock):
model_name = "wan"

@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
]

@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("attention_kwargs"),
]

@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"concatenated_latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process.",
),
InputParam(
"encoder_hidden_states_image",
required=True,
type_hint=torch.Tensor,
description="The encoder hidden states for the image inputs.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
InputParam(
kwargs_type="guider_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]

@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_input_fields = {
"prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
}
transformer_dtype = components.transformer.dtype

components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

# Prepare mini‐batches according to guidance method and `guider_input_fields`
# Each guider_state_batch will have .prompt_embeds.
# e.g. for CFG, we prepare two batches: one for uncond, one for cond
# for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
# for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)

# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
prompt_embeds = cond_kwargs.pop("prompt_embeds")

# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.concatenated_latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds.to(transformer_dtype),
encoder_hidden_states_image=block_state.encoder_hidden_states_image.to(transformer_dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
)[0]
Expand Down Expand Up @@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
WanLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
block_names = ["denoiser", "after_denoiser"]

@property
def description(self) -> str:
Expand All @@ -257,5 +412,26 @@ def description(self) -> str:
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
"This block supports the text2vid task."
)


class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanI2VLoopBeforeDenoiser,
WanI2VLoopDenoiser,
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]

@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `WanI2VLoopBeforeDenoiser`\n"
" - `WanI2VLoopDenoiser`\n"
" - `WanI2VLoopAfterDenoiser`\n"
"This block supports the image-to-video and first-last-frame-to-video tasks."
)
Loading
Loading