
Commit 20d738c

refactor image-to-video pipeline
1 parent 87b4b9e commit 20d738c

File tree

3 files changed: +64 −69 lines changed

src/diffusers/pipelines/wan/pipeline_wan.py
src/diffusers/pipelines/wan/pipeline_wan_i2v.py
tests/pipelines/wan/test_wan_image_to_video.py

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 11 additions & 12 deletions
@@ -300,10 +300,10 @@ def check_inputs(
     def prepare_latents(
         self,
         batch_size: int,
-        num_channels_latents: 16,
-        height: int = 720,
-        width: int = 1280,
-        num_latent_frames: int = 21,
+        num_channels_latents: int = 16,
+        height: int = 480,
+        width: int = 832,
+        num_frames: int = 81,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -312,6 +312,7 @@ def prepare_latents(
         if latents is not None:
             return latents.to(device=device, dtype=dtype)

+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         shape = (
             batch_size,
             num_channels_latents,
@@ -358,8 +359,8 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
-        height: int = 720,
-        width: int = 1280,
+        height: int = 480,
+        width: int = 832,
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
@@ -384,11 +385,11 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, defaults to `720`):
+            height (`int`, defaults to `480`):
                 The height in pixels of the generated image.
-            width (`int`, defaults to `1280`):
+            width (`int`, defaults to `832`):
                 The width in pixels of the generated image.
-            num_frames (`int`, defaults to `129`):
+            num_frames (`int`, defaults to `81`):
                 The number of frames in the generated video.
             num_inference_steps (`int`, defaults to `50`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -492,14 +493,12 @@ def __call__(

         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
-        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
-
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
             width,
-            num_latent_frames,
+            num_frames,
             torch.float32,
             device,
             generator,
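
Note: prepare_latents in pipeline_wan.py now takes the raw num_frames and derives the latent frame count itself, instead of the caller precomputing num_latent_frames. A minimal sketch of that mapping, assuming vae_scale_factor_temporal is 4 (an assumption about the Wan VAE, not stated in this diff):

    # Hypothetical standalone helper mirroring the line moved into prepare_latents.
    def to_latent_frames(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
        # The first frame is kept as-is; every following group of
        # vae_scale_factor_temporal frames collapses into one latent frame.
        return (num_frames - 1) // vae_scale_factor_temporal + 1

    print(to_latent_frames(81))  # 21, matching the old hard-coded default num_latent_frames=21
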

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Lines changed: 49 additions & 54 deletions
@@ -16,7 +16,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import ftfy
-import numpy as np
 import PIL
 import regex as re
 import torch
@@ -165,7 +164,7 @@ def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
         num_videos_per_prompt: int = 1,
-        max_sequence_length: int = 226,
+        max_sequence_length: int = 512,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -292,15 +291,18 @@ def encode_prompt(
     def check_inputs(
         self,
         prompt,
+        negative_prompt,
         image,
-        max_area,
+        height,
+        width,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
             raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
-        if max_area < 0:
-            raise ValueError(f"`max_area` has to be positive but are {max_area}.")
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
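
Note: check_inputs now validates the caller-supplied height and width directly instead of checking max_area. The factor 16 presumably combines vae_scale_factor_spatial (8) and the transformer's spatial patch size (2), the same mod_value the removed max_area logic used; both numbers are assumptions here, not part of this diff. A small sketch of snapping arbitrary dimensions so they pass the new check:

    # Illustrative only: round dimensions down to a multiple of 16 before calling the pipeline.
    def snap(value: int, multiple: int = 16) -> int:
        return max(multiple, (value // multiple) * multiple)

    height, width = snap(481), snap(838)
    print(height, width)             # 480 832
    print(height % 16, width % 16)   # 0 0 -- passes the new divisibility check
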
@@ -314,43 +316,43 @@ def check_inputs(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                 " only forward one of the two."
             )
+        elif negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
         elif prompt is None and prompt_embeds is None:
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif negative_prompt is not None and (
+            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+        ):
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

     def prepare_latents(
         self,
         image: PipelineImageInput,
         batch_size: int,
-        num_channels_latents: 32,
-        height: int = 720,
-        width: int = 1280,
-        max_area: int = 720 * 1280,
+        num_channels_latents: int = 16,
+        height: int = 480,
+        width: int = 832,
         num_frames: int = 81,
-        num_latent_frames: int = 21,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        aspect_ratio = height / width
-        mod_value = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
-        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-
         if latents is not None:
             return latents.to(device=device, dtype=dtype)

-        shape = (
-            batch_size,
-            num_channels_latents,
-            num_latent_frames,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        latent_height = height // self.vae_scale_factor_spatial
+        latent_width = width // self.vae_scale_factor_spatial
+
+        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -359,35 +361,25 @@ def prepare_latents(

         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

-        image = self.video_processor.preprocess(image, height=height, width=width)[:, :, None]
+        image = image.unsqueeze(2)
         video_condition = torch.cat(
             [image, torch.zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
         )
         video_condition = video_condition.to(device=device, dtype=dtype)
+
         if isinstance(generator, list):
             latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
             latents = latent_condition = torch.cat(latent_condition)
         else:
             latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
-        mask_lat_size = torch.ones(
-            batch_size,
-            1,
-            num_frames,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
         first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
         mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
-        mask_lat_size = mask_lat_size.view(
-            batch_size,
-            -1,
-            self.vae_scale_factor_temporal,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
+        mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
         mask_lat_size = mask_lat_size.transpose(1, 2)
         mask_lat_size = mask_lat_size.to(latent_condition.device)

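Note: the conditioning-mask block is behaviorally unchanged; it is just rewritten against the precomputed latent_height/latent_width. A compact sketch of what it produces, using assumed values (batch 1, 81 frames, temporal scale 4, 60x104 latent grid) that are not taken from this commit:

    import torch

    batch_size, num_frames, vae_scale_factor_temporal = 1, 81, 4  # assumed
    latent_height, latent_width = 60, 104                         # assumed

    mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
    mask[:, :, 1:] = 0  # only the first (conditioning) frame stays 1
    first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
    mask = torch.cat([first, mask[:, :, 1:]], dim=2)  # 4 + 80 = 84 frames
    mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
    mask = mask.transpose(1, 2)
    print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]) -- 21 latent frames, matching the latents
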
@@ -424,7 +416,8 @@ def __call__(
         image: PipelineImageInput,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
-        max_area: int = 720 * 1280,
+        height: int = 480,
+        width: int = 832,
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
@@ -451,9 +444,15 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            max_area (`int`, defaults to `1280 * 720`):
-                The maximum area in pixels of the generated image.
-            num_frames (`int`, defaults to `129`):
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, defaults to `480`):
+                The height of the generated video.
+            width (`int`, defaults to `832`):
+                The width of the generated video.
+            num_frames (`int`, defaults to `81`):
                 The number of frames in the generated video.
             num_inference_steps (`int`, defaults to `50`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -514,9 +513,12 @@ def __call__(
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
+            negative_prompt,
             image,
-            max_area,
+            height,
+            width,
             prompt_embeds,
+            negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
         )

@@ -548,36 +550,29 @@ def __call__(
         )

         # Encode image embedding
-        image_embeds = self.encode_image(image)
-        image_embeds = image_embeds.repeat(batch_size, 1, 1)
-
         transformer_dtype = self.transformer.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+        image_embeds = self.encode_image(image)
+        image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

-        if isinstance(image, torch.Tensor):
-            height, width = image.shape[-2:]
-        else:
-            width, height = image.size
-
         # 5. Prepare latent variables
-        num_channels_latents = self.vae.config.z_dim
-        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        num_channels_latents = self.transformer.config.in_channels
+        image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
         latents, condition = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
             width,
-            max_area,
             num_frames,
-            num_latent_frames,
             torch.float32,
             device,
             generator,
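
Note: end to end, the image-to-video pipeline now takes explicit height/width instead of max_area, and the input image is resized inside __call__ via video_processor.preprocess. A hedged usage sketch; the checkpoint id, fps, and chosen resolution are assumptions for illustration, not taken from this commit:

    import torch
    from diffusers import WanImageToVideoPipeline
    from diffusers.utils import export_to_video, load_image

    # Any Wan i2v checkpoint in the diffusers format should work; this id is an assumption.
    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    image = load_image("input.png")
    video = pipe(
        image=image,
        prompt="a cat surfing a wave at sunset",
        height=480,   # must be divisible by 16 under the new check_inputs
        width=832,
        num_frames=81,
    ).frames[0]
    export_to_video(video, "output.mp4", fps=16)
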

tests/pipelines/wan/test_wan_image_to_video.py

Lines changed: 4 additions & 3 deletions
@@ -125,7 +125,8 @@ def get_dummy_inputs(self, device, seed=0):
             "image": image,
             "prompt": "dance monkey",
             "negative_prompt": "negative",  # TODO
-            "max_area": 1024,
+            "height": image_height,
+            "width": image_width,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
@@ -147,8 +148,8 @@ def test_inference(self):
         video = pipe(**inputs).frames
         generated_video = video[0]

-        self.assertEqual(generated_video.shape, (9, 3, 32, 32))
-        expected_video = torch.randn(9, 3, 32, 32)
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+        expected_video = torch.randn(9, 3, 16, 16)
         max_diff = np.abs(generated_video - expected_video).max()
         self.assertLessEqual(max_diff, 1e10)
