Skip to content

Commit 598ca27

Browse files
committed
update
1 parent aa2c37a commit 598ca27

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

src/diffusers/pipelines/wan/pipeline_wan_video2video.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -395,11 +395,13 @@ def prepare_latents(
395395
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
396396
)
397397

398-
num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
398+
num_latent_frames = (
399+
(video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
400+
)
399401
shape = (
400402
batch_size,
401-
num_frames,
402403
num_channels_latents,
404+
num_latent_frames,
403405
height // self.vae_scale_factor_spatial,
404406
width // self.vae_scale_factor_spatial,
405407
)
@@ -412,10 +414,19 @@ def prepare_latents(
412414
else:
413415
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
414416

415-
init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
417+
init_latents = torch.cat(init_latents, dim=0).to(dtype)
418+
419+
latents_mean = (
420+
torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
421+
)
422+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
423+
device, dtype
424+
)
425+
426+
init_latents = (init_latents - latents_mean) * latents_std
416427

417428
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
418-
latents = self.scheduler.scale_noise(init_latents, timestep, noise)
429+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
419430
else:
420431
latents = latents.to(device)
421432

@@ -464,7 +475,6 @@ def __call__(
464475
negative_prompt: Union[str, List[str]] = None,
465476
height: int = 480,
466477
width: int = 832,
467-
num_frames: int = 81,
468478
num_inference_steps: int = 50,
469479
timesteps: Optional[List[int]] = None,
470480
guidance_scale: float = 5.0,
@@ -605,8 +615,9 @@ def __call__(
605615
self._num_timesteps = len(timesteps)
606616

607617
if latents is None:
608-
video = self.video_processor.preprocess_video(video, height=height, width=width)
609-
video = video.to(device=device, dtype=prompt_embeds.dtype)
618+
video = self.video_processor.preprocess_video(video, height=height, width=width).to(
619+
device, dtype=torch.float32
620+
)
610621

611622
# 5. Prepare latent variables
612623
num_channels_latents = self.transformer.config.in_channels

tests/pipelines/wan/test_wan_video_to_video.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
from PIL import Image
2121
from transformers import AutoTokenizer, T5EncoderModel
2222

23-
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
23+
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
2424
from diffusers.utils.testing_utils import (
2525
enable_full_determinism,
2626
require_torch_accelerator,
2727
slow,
2828
)
2929

30-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
30+
from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
3131
from ..test_pipelines_common import (
3232
PipelineTesterMixin,
3333
)
@@ -39,8 +39,7 @@
3939
class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4040
pipeline_class = WanVideoToVideoPipeline
4141
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
42-
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
43-
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
42+
batch_params = frozenset(["video", "prompt", "negative_prompt"])
4443
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
4544
required_optional_params = frozenset(
4645
[
@@ -66,8 +65,7 @@ def get_dummy_components(self):
6665
)
6766

6867
torch.manual_seed(0)
69-
# TODO: impl FlowDPMSolverMultistepScheduler
70-
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
68+
scheduler = UniPCMultistepScheduler(flow_shift=3.0)
7169
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
7270
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
7371

@@ -102,7 +100,7 @@ def get_dummy_inputs(self, device, seed=0):
102100
else:
103101
generator = torch.Generator(device=device).manual_seed(seed)
104102

105-
video = [Image.new("RGB", (16, 16))] * 19
103+
video = [Image.new("RGB", (16, 16))] * 17
106104
inputs = {
107105
"video": video,
108106
"prompt": "dance monkey",
@@ -112,7 +110,6 @@ def get_dummy_inputs(self, device, seed=0):
112110
"guidance_scale": 6.0,
113111
"height": 16,
114112
"width": 16,
115-
"num_frames": 9,
116113
"max_sequence_length": 16,
117114
"output_type": "pt",
118115
}
@@ -130,15 +127,27 @@ def test_inference(self):
130127
video = pipe(**inputs).frames
131128
generated_video = video[0]
132129

133-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
134-
expected_video = torch.randn(9, 3, 16, 16)
130+
self.assertEqual(generated_video.shape, (17, 3, 16, 16))
131+
expected_video = torch.randn(17, 3, 16, 16)
135132
max_diff = np.abs(generated_video - expected_video).max()
136133
self.assertLessEqual(max_diff, 1e10)
137134

138135
@unittest.skip("Test not supported")
139136
def test_attention_slicing_forward_pass(self):
140137
pass
141138

139+
@unittest.skip(
140+
"WanVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
141+
)
142+
def test_float16_inference(self):
143+
pass
144+
145+
@unittest.skip(
146+
"WanVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
147+
)
148+
def test_save_load_float16(self):
149+
pass
150+
142151

143152
@slow
144153
@require_torch_accelerator

0 commit comments

Comments
 (0)