
Commit a9768d2

Commit message: tests

1 parent 157a24d · commit a9768d2

File tree: 6 files changed, +171 −31 lines

src/diffusers/pipelines/wan/pipeline_wan.py
Lines changed: 1 addition & 5 deletions

@@ -114,11 +114,7 @@ class WanPipeline(DiffusionPipeline):
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
-    _callback_tensor_inputs = [
-        "latents",
-        "prompt_embeds",
-        "negative_prompt_embeds",
-    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
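
Note: the `_callback_tensor_inputs` edit only collapses the list onto one line; the names in it are still what gates which tensors `callback_on_step_end` may read and modify during denoising. A minimal usage sketch — the checkpoint id and the callback body are illustrative assumptions, not part of this commit:

import torch
from diffusers import WanPipeline

# Hypothetical checkpoint id; substitute any WanPipeline checkpoint.
pipe = WanPipeline.from_pretrained("<wan-checkpoint>", torch_dtype=torch.bfloat16)

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Only names listed in WanPipeline._callback_tensor_inputs may be requested here.
    latents = callback_kwargs["latents"]
    print(step, latents.shape)
    return callback_kwargs  # returned tensors are fed back into the denoising loop

video = pipe(
    prompt="a cat surfing a wave",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]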

src/diffusers/pipelines/wan/pipeline_wan_i2v.py
Lines changed: 6 additions & 9 deletions

@@ -153,11 +153,7 @@ class WanI2VPipeline(DiffusionPipeline):
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
-    _callback_tensor_inputs = [
-        "latents",
-        "prompt_embeds",
-        "negative_prompt_embeds",
-    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
@@ -231,7 +227,7 @@ def _get_t5_prompt_embeds(
    def encode_image(self, image: PipelineImageInput):
        image = self.image_processor(images=image, return_tensors="pt").to(self.device)
        image_embeds = self.image_encoder(**image, output_hidden_states=True)
-        return image_embeds.hidden_states[31]
+        return image_embeds.hidden_states[-1]

    # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
    def encode_prompt(
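
Note: with `output_hidden_states=True`, a transformers CLIP vision encoder returns one hidden state per layer plus the embedding output, so the hard-coded index 31 only exists for a 32-layer encoder; `[-1]` works for any depth, including the 2-layer dummy encoder in the new fast test below (for a deep encoder the two indices do select different layers). A quick check using the same tiny config as that test:

import torch
from transformers import CLIPVisionConfig, CLIPVisionModel

config = CLIPVisionConfig(
    hidden_size=4, num_hidden_layers=2, num_attention_heads=2,
    image_size=32, intermediate_size=16, patch_size=1,
)
model = CLIPVisionModel(config)
out = model(pixel_values=torch.randn(1, 3, 32, 32), output_hidden_states=True)

print(len(out.hidden_states))       # 3 == num_hidden_layers + 1, so [31] would raise IndexError
print(out.hidden_states[-1].shape)  # last layer's hidden states, what encode_image now returns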
@@ -392,7 +388,7 @@ def prepare_latents(
        video_condition = video_condition.to(device=device, dtype=dtype)
        if isinstance(generator, list):
            latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
-            latents = torch.stack(latent_condition)
+            latents = latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
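
Note: each `retrieve_latents(self.vae.encode(...), g)` call already returns a tensor with a leading batch dimension, so `torch.stack` would have inserted an extra axis (and the old line only populated `latents`, never `latent_condition`); `torch.cat` joins the per-generator results along the existing batch axis and assigns both names. A toy illustration with made-up latent shapes:

import torch

# Pretend these are two per-generator VAE encodings, each already (1, C, T, H, W).
per_generator = [torch.randn(1, 16, 3, 4, 4) for _ in range(2)]

stacked = torch.stack(per_generator)  # (2, 1, 16, 3, 4, 4) -- an extra axis the rest of the pipeline cannot use
catted = torch.cat(per_generator)     # (2, 16, 3, 4, 4)    -- a proper batch of latent conditions

print(stacked.shape, catted.shape)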
@@ -474,7 +470,7 @@ def __call__(
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            max_area (`int`, defaults to `1280 * 720`):
-                The maximum area in pixels of the generated image. The width in pixels of the generated image.
+                The maximum area in pixels of the generated image.
            num_frames (`int`, defaults to `129`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
@@ -570,7 +566,8 @@ def __call__(

        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
-        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+        if negative_prompt_embeds is not None:
+            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
        image_embeds = image_embeds.to(transformer_dtype)

        # 4. Prepare timesteps

src/diffusers/utils/dummy_pt_objects.py
Lines changed: 7 additions & 7 deletions

@@ -111,7 +111,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLWan(metaclass=DummyObject):
+class AutoencoderKLAllegro(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -126,7 +126,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLAllegro(metaclass=DummyObject):
+class AutoencoderKLCogVideoX(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -141,7 +141,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLCogVideoX(metaclass=DummyObject):
+class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -156,7 +156,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
+class AutoencoderKLLTXVideo(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -171,7 +171,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLLTXVideo(metaclass=DummyObject):
+class AutoencoderKLMochi(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -186,7 +186,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLMochi(metaclass=DummyObject):
+class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -201,7 +201,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


-class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
+class AutoencoderKLWan(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
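
Note: both dummy-object files follow the same pattern — every public class gets a placeholder whose constructor and `from_pretrained` call `requires_backends`, so `import diffusers` still succeeds when torch (or transformers) is missing, but touching one of these classes raises a clear error. This commit only reorders entries so the new `AutoencoderKLWan` placeholder sits in alphabetical order. A self-contained sketch of the idea — an illustrative re-implementation, not the actual `DummyObject`/`requires_backends` code in `diffusers.utils`:

class DummyObject(type):
    # Illustrative stand-in: instantiating any class with this metaclass fails fast.
    def __call__(cls, *args, **kwargs):
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class AutoencoderKLWan(metaclass=DummyObject):
    _backends = ["torch"]

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


try:
    AutoencoderKLWan()
except ImportError as err:
    print(err)  # AutoencoderKLWan requires the following backends: torch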

src/diffusers/utils/dummy_torch_and_transformers_objects.py
Lines changed: 5 additions & 5 deletions

@@ -2552,7 +2552,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


-class WuerstchenCombinedPipeline(metaclass=DummyObject):
+class WanI2VPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -2567,7 +2567,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


-class WuerstchenDecoderPipeline(metaclass=DummyObject):
+class WanPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -2582,7 +2582,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


-class WuerstchenPriorPipeline(metaclass=DummyObject):
+class WuerstchenCombinedPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -2597,7 +2597,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


-class WanPipeline(metaclass=DummyObject):
+class WuerstchenDecoderPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

@@ -2612,7 +2612,7 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


-class WanI2VPipelin(metaclass=DummyObject):
+class WuerstchenPriorPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):

tests/models/autoencoders/test_models_autoencoder_wan.py
Lines changed: 1 addition & 5 deletions

@@ -16,11 +16,7 @@
import unittest

from diffusers import AutoencoderKLWan
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-    floats_tensor,
-    torch_device
-)
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin

Lines changed: 151 additions & 0 deletions (new file)

@@ -0,0 +1,151 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel, CLIPVisionConfig, CLIPVisionModel, CLIPImageProcessor

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanI2VPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()


class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = WanI2VPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        # TODO: impl FlowDPMSolverMultistepScheduler
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=36,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
            image_embedding_dim=4,
        )

        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            hidden_size=4,
            projection_dim=4,
            num_hidden_layers=2,
            num_attention_heads=2,
            image_size=32,
            intermediate_size=16,
            patch_size=1,
        )
        image_encoder = CLIPVisionModel(image_encoder_config)

        torch.manual_seed(0)
        image_processor = CLIPImageProcessor(crop_size=32, size=32)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "image_encoder": image_encoder,
            "image_processor": image_processor,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        image_height = 16
        image_width = 16
        image = Image.new("RGB", (image_width, image_height))
        inputs = {
            "image": image,
            "prompt": "dance monkey",
            "negative_prompt": "negative",  # TODO
            "max_area": 1024,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (9, 3, 32, 32))
        expected_video = torch.randn(9, 3, 32, 32)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass
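
Note: the shape asserted in `test_inference` can be sanity-checked from the dummy inputs — with `output_type="pt"` each returned video is (num_frames, channels, height, width); `num_frames=9` is requested directly, the output is RGB (3 channels), and with a square 16x16 input image and `max_area=1024` the pipeline resizes to a 1:1 resolution whose area fits the budget, i.e. 32x32. A back-of-the-envelope check (the pipeline's exact rounding to model-friendly multiples is an assumption here):

import math

max_area = 1024
image_width, image_height = 16, 16                  # the dummy PIL image from get_dummy_inputs
aspect_ratio = image_height / image_width

# Resize so height * width is roughly max_area while preserving the aspect ratio.
height = round(math.sqrt(max_area * aspect_ratio))  # 32
width = round(math.sqrt(max_area / aspect_ratio))   # 32
print((9, 3, height, width))                        # matches the asserted (9, 3, 32, 32)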
