Skip to content

Commit 2f1f0cf

Browse files
committed
update
1 parent 9ea73b7 commit 2f1f0cf

File tree

5 files changed

+470
-36
lines changed

5 files changed

+470
-36
lines changed

src/diffusers/modular_pipelines/wan/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@
2121

2222
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2323
else:
24-
_import_structure["encoders"] = []
24+
_import_structure["encoders"] = ["WanTextEncoderStep"]
2525
_import_structure["modular_blocks"] = [
2626
"ALL_BLOCKS",
2727
"AUTO_BLOCKS",
2828
"TEXT2VIDEO_BLOCKS",
29+
"WanAutoBeforeDenoiseStep",
2930
"WanAutoBlocks",
31+
"WanAutoBlocks",
32+
"WanAutoDecodeStep",
33+
"WanAutoDenoiseStep",
3034
]
3135
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
3236

@@ -37,11 +41,15 @@
3741
except OptionalDependencyNotAvailable:
3842
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
3943
else:
44+
from .encoders import WanTextEncoderStep
4045
from .modular_blocks import (
4146
ALL_BLOCKS,
4247
AUTO_BLOCKS,
4348
TEXT2VIDEO_BLOCKS,
49+
WanAutoBeforeDenoiseStep,
4450
WanAutoBlocks,
51+
WanAutoDecodeStep,
52+
WanAutoDenoiseStep,
4553
)
4654
from .modular_pipeline import WanModularPipeline
4755
else:

src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import torch
1919

20-
from ...models import AutoencoderKLWan
2120
from ...schedulers import FlowMatchEulerDiscreteScheduler
2221
from ...utils import logging
2322
from ...utils.torch_utils import randn_tensor
@@ -230,7 +229,6 @@ def intermediate_outputs(self) -> List[OutputParam]:
230229
@torch.no_grad()
231230
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
232231
block_state = self.get_block_state(state)
233-
234232
block_state.device = components._execution_device
235233

236234
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
@@ -250,10 +248,7 @@ class WanPrepareLatentsStep(PipelineBlock):
250248

251249
@property
252250
def expected_components(self) -> List[ComponentSpec]:
253-
return [
254-
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
255-
ComponentSpec("vae", AutoencoderKLWan),
256-
]
251+
return []
257252

258253
@property
259254
def description(self) -> str:
@@ -262,11 +257,11 @@ def description(self) -> str:
262257
@property
263258
def inputs(self) -> List[InputParam]:
264259
return [
265-
InputParam("height"),
266-
InputParam("width"),
267-
InputParam("num_frames"),
268-
InputParam("latents"),
269-
InputParam("num_videos_per_prompt", default=1),
260+
InputParam("height", type_hint=int),
261+
InputParam("width", type_hint=int),
262+
InputParam("num_frames", type_hint=int),
263+
InputParam("latents", type_hint=Optional[torch.Tensor]),
264+
InputParam("num_videos_per_prompt", type_hint=int, default=1),
270265
]
271266

272267
@property
@@ -277,7 +272,7 @@ def intermediate_inputs(self) -> List[InputParam]:
277272
"batch_size",
278273
required=True,
279274
type_hint=int,
280-
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt. Can be generated in input step.",
275+
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
281276
),
282277
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
283278
]
@@ -343,17 +338,15 @@ def prepare_latents(
343338
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
344339
block_state = self.get_block_state(state)
345340

346-
if block_state.dtype is None:
347-
block_state.dtype = components.vae.dtype
348-
349-
block_state.device = components._execution_device
350-
351-
self.check_inputs(components, block_state)
352-
353341
block_state.height = block_state.height or components.default_height
354342
block_state.width = block_state.width or components.default_width
355343
block_state.num_frames = block_state.num_frames or components.default_num_frames
344+
block_state.device = components._execution_device
345+
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
356346
block_state.num_channels_latents = components.num_channels_latents
347+
348+
self.check_inputs(components, block_state)
349+
357350
block_state.latents = self.prepare_latents(
358351
components,
359352
block_state.batch_size * block_state.num_videos_per_prompt,
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, List, Tuple, Union
16+
17+
import numpy as np
18+
import PIL
19+
import torch
20+
21+
from ...configuration_utils import FrozenDict
22+
from ...models import AutoencoderKLWan
23+
from ...utils import logging
24+
from ...video_processor import VideoProcessor
25+
from ..modular_pipeline import PipelineBlock, PipelineState
26+
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
27+
28+
29+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30+
31+
32+
class WanDecodeStep(PipelineBlock):
33+
model_name = "stable-diffusion-xl"
34+
35+
@property
36+
def expected_components(self) -> List[ComponentSpec]:
37+
return [
38+
ComponentSpec("vae", AutoencoderKLWan),
39+
ComponentSpec(
40+
"video_processor",
41+
VideoProcessor,
42+
config=FrozenDict({"vae_scale_factor": 8}),
43+
default_creation_method="from_config",
44+
),
45+
]
46+
47+
@property
48+
def description(self) -> str:
49+
return "Step that decodes the denoised latents into images"
50+
51+
@property
52+
def inputs(self) -> List[Tuple[str, Any]]:
53+
return [
54+
InputParam("output_type", default="pil"),
55+
]
56+
57+
@property
58+
def intermediate_inputs(self) -> List[str]:
59+
return [
60+
InputParam(
61+
"latents",
62+
required=True,
63+
type_hint=torch.Tensor,
64+
description="The denoised latents from the denoising step",
65+
)
66+
]
67+
68+
@property
69+
def intermediate_outputs(self) -> List[str]:
70+
return [
71+
OutputParam(
72+
"videos",
73+
type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
74+
description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
75+
)
76+
]
77+
78+
@torch.no_grad()
79+
def __call__(self, components, state: PipelineState) -> PipelineState:
80+
block_state = self.get_block_state(state)
81+
vae_dtype = components.vae.dtype
82+
83+
if not block_state.output_type == "latent":
84+
latents = block_state.latents
85+
latents_mean = (
86+
torch.tensor(components.vae.config.latents_mean)
87+
.view(1, components.vae.config.z_dim, 1, 1, 1)
88+
.to(latents.device, latents.dtype)
89+
)
90+
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
91+
1, components.vae.config.z_dim, 1, 1, 1
92+
).to(latents.device, latents.dtype)
93+
latents = latents / latents_std + latents_mean
94+
latents = latents.to(vae_dtype)
95+
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
96+
else:
97+
block_state.videos = block_state.latents
98+
99+
block_state.videos = components.video_processor.postprocess_video(
100+
block_state.videos, output_type=block_state.output_type
101+
)
102+
103+
self.set_block_state(state, block_state)
104+
105+
return components, state

0 commit comments

Comments
 (0)