
Commit 183bcd5
committed: update
1 parent 3d2f8ae commit 183bcd5

2 files changed: +279 −3 lines changed


src/diffusers/modular_pipelines/wan/encoders.py

Lines changed: 248 additions & 1 deletion
@@ -17,11 +17,14 @@
 
 import regex as re
 import torch
-from transformers import AutoTokenizer, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
 from ...utils import is_ftfy_available, logging
+from ...video_processor import VideoProcessor
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class WanTextEncoderStep(PipelineBlock):
     model_name = "wan"
 
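Note on `retrieve_latents` (added above): it normalizes the shapes a diffusers VAE `encode` output can take. A minimal sketch of the two sample modes used in this file, with a stand-in distribution object rather than the real diffusers class:

import torch

class FakeLatentDist:
    # Stand-in for the `latent_dist` attribute on diffusers VAE encode outputs.
    def __init__(self, mean):
        self.mean = mean

    def sample(self, generator=None):
        # Stochastic draw from the posterior (used when sample_mode="sample").
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        # Deterministic mode of the distribution (used when sample_mode="argmax").
        return self.mean

class FakeEncoderOutput:
    def __init__(self, mean):
        self.latent_dist = FakeLatentDist(mean)

out = FakeEncoderOutput(torch.zeros(1, 16, 1, 8, 8))
noisy = retrieve_latents(out, sample_mode="sample")   # varies run to run
exact = retrieve_latents(out, sample_mode="argmax")   # what WanVaeEncoderStep uses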
@@ -240,3 +257,233 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
         # Add outputs
         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanImageEncodeStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "Image Encoder step to compute image embeddings to guide the video generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_encoder", CLIPVisionModel),
+            ComponentSpec("image_processor", CLIPImageProcessor),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return []
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "image",
+                required=True,
+                description="The input image to condition the generation on for first-frame conditioned video generation.",
+            ),
+            InputParam(
+                "last_image",
+                required=False,
+                description="The last image to condition the generation on for last-frame conditioned video generation.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "encoder_hidden_states_image",
+                type_hint=torch.Tensor,
+                description="image embeddings used to guide the video generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        if not isinstance(block_state.image, PipelineImageInput):
+            raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
+        if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
+            raise ValueError(
+                f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
+            )
+
+    @staticmethod
+    def encode_image(
+        components,
+        image: PipelineImageInput,
+        device: torch.device,
+    ):
+        image = components.image_processor(images=image, return_tensors="pt").to(device)
+        image_embeds = components.image_encoder(**image, output_hidden_states=True)
+        return image_embeds.hidden_states[-2]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+        block_state.device = components._execution_device
+
+        # Encode input images
+        image = block_state.image
+        if block_state.last_image is not None:
+            image = [block_state.image, block_state.last_image]
+
+        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+
+        # Add outputs
+        self.set_block_state(state, block_state)
+        return components, state
+
+
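Usage sketch for the `encode_image` path above: run the processor, forward through `CLIPVisionModel` with `output_hidden_states=True`, and keep the penultimate hidden state. The checkpoint id below is an illustrative assumption, not something this commit pins:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

# Illustrative checkpoint; Wan repos ship their own CLIP vision weights.
repo = "openai/clip-vit-large-patch14"
image_processor = CLIPImageProcessor.from_pretrained(repo)
image_encoder = CLIPVisionModel.from_pretrained(repo)

image = Image.new("RGB", (512, 512))  # placeholder first frame
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = image_encoder(**inputs, output_hidden_states=True)

# Penultimate layer, matching `image_embeds.hidden_states[-2]` above.
encoder_hidden_states_image = outputs.hidden_states[-2]
print(encoder_hidden_states_image.shape)  # (1, num_patches + 1, hidden_size)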
+class WanVaeEncoderStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return (
+            "VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLWan),
+            ComponentSpec(
+                "video_processor",
+                VideoProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("last_image", required=False),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("num_frames", type_hint=int),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("num_channels_latents", type_hint=int),
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latent_condition",
+                type_hint=torch.Tensor,
+                description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
+            )
+        ]
+
+    def _encode_vae_image(
+        self,
+        components: WanModularPipeline,
+        batch_size: int,
+        height: int,
+        width: int,
+        num_frames: int,
+        image: torch.Tensor,
+        device: torch.device,
+        dtype: torch.dtype,
+        last_image: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ):
+        latent_height = height // components.vae_scale_factor_spatial
+        latent_width = width // components.vae_scale_factor_spatial
+
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean).view(1, components.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(device, dtype)
+
+        image = image.unsqueeze(2)
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
+        video_condition = video_condition.to(device=device, dtype=dtype)
+
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+        latent_condition = latent_condition.to(dtype)
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(latent_condition.device)
+        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
+
+        return latent_condition
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+        block_state.num_channels_latents = components.vae.config.z_dim
+        block_state.batch_size = (
+            block_state.batch_size if block_state.batch_size is not None else block_state.image.shape[0]
+        )
+
+        block_state.image = components.video_processor.preprocess(
+            block_state.image, height=block_state.height, width=block_state.width
+        ).to(block_state.device, dtype=torch.float32)
+        if block_state.last_image is not None:
+            block_state.last_image = components.video_processor.preprocess(
+                block_state.last_image, height=block_state.height, width=block_state.width
+            ).to(block_state.device, dtype=torch.float32)
+
+        block_state.latent_condition = self._encode_vae_image(
+            components,
+            batch_size=block_state.batch_size,
+            height=block_state.height,
+            width=block_state.width,
+            num_frames=block_state.num_frames,
+            image=block_state.image,
+            device=block_state.device,
+            dtype=block_state.dtype,
+            last_image=block_state.last_image,
+            generator=block_state.generator,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
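The subtle part of `_encode_vae_image` is the mask bookkeeping: the per-frame mask must fold evenly into latent frames, so frame 0 is repeated `vae_scale_factor_temporal` times before the reshape. A standalone sketch with illustrative sizes (temporal factor 4 and spatial factor 8 are Wan's usual values, assumed here rather than read from the diff):

import torch

batch_size, num_frames = 1, 81          # 81 pixel frames, a common Wan length
latent_height, latent_width = 60, 104   # 480x832 video with spatial factor 8
vae_scale_factor_temporal = 4

mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0  # first-frame conditioning: only frame 0 is known

# Repeat frame 0 so the mask length (4 + 80 = 84) is divisible by the factor.
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask = mask.transpose(1, 2)

print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]); (81 - 1) / 4 + 1 = 21 latent frames

This mask is then concatenated with the normalized VAE latents along the channel dimension, which is why `latent_condition` ends up with more channels than the VAE's `z_dim`.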

src/diffusers/modular_pipelines/wan/modular_blocks.py

Lines changed: 31 additions & 2 deletions
@@ -22,12 +22,27 @@
 )
 from .decoders import WanDecodeStep
 from .denoise import WanDenoiseStep
-from .encoders import WanTextEncoderStep
+from .encoders import WanTextEncoderStep, WanVaeEncoderStep
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class WanAutoVaeEncoderStep(AutoPipelineBlocks):
+    block_classes = [WanVaeEncoderStep]
+    block_names = ["img2vid"]
+    block_trigger_inputs = ["image"]
+
+    @property
+    def description(self):
+        return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+            + "This is an auto pipeline block that works for both first-frame and first-last-frame conditioning tasks.\n"
+            + " - `WanVaeEncoderStep` (img2vid) is used when `image`, and possibly `last_image`, is provided.\n"
+            + " - if `image` is not provided, this step will be skipped."
+        )
+
+
 # before_denoise: text2vid
 class WanBeforeDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
@@ -97,6 +112,7 @@ def description(self):
 class WanAutoBlocks(SequentialPipelineBlocks):
     block_classes = [
         WanTextEncoderStep,
+        WanAutoVaeEncoderStep,
         WanAutoBeforeDenoiseStep,
         WanAutoDenoiseStep,
         WanAutoDecodeStep,
@@ -128,10 +144,23 @@ def description(self):
     )
 
 
+IMAGE2VIDEO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", WanTextEncoderStep),
+        ("input", WanInputStep),
+        ("image_encoder", WanVaeEncoderStep),
+        ("set_timesteps", WanSetTimestepsStep),
+        ("prepare_latents", WanPrepareLatentsStep),
+        ("denoise", WanDenoiseStep),
+        ("decode", WanDecodeStep),
+    ]
+)
+
+
 AUTO_BLOCKS = InsertableDict(
     [
         ("text_encoder", WanTextEncoderStep),
-        ("before_denoise", WanAutoBeforeDenoiseStep),
+        ("image_encoder", WanAutoVaeEncoderStep),
+        ("before_denoise", WanAutoBeforeDenoiseStep),
         ("denoise", WanAutoDenoiseStep),
         ("decode", WanAutoDecodeStep),
     ]
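For orientation, a sketch of how these presets are typically assembled with the modular-pipelines API, assuming the usual `from_blocks_dict`/`init_pipeline` flow; the repo id is hypothetical:

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.wan.modular_blocks import AUTO_BLOCKS

# Compose the preset into one sequential block, then bind it to a repo that
# holds the component specs (repo id here is hypothetical).
blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)
pipe = blocks.init_pipeline("your-org/wan-modular")

# Passing `image` makes WanAutoVaeEncoderStep trigger its img2vid sub-block;
# omitting it skips the VAE-encode step entirely, giving plain text-to-video.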
