
Commit 157a24d

committed: update tests
1 parent 425a3b0 commit 157a24d

File tree

4 files changed: +239 −261 lines changed

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin


enable_full_determinism()


class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AutoencoderKLWan
    main_input_name = "sample"
    base_precision = 1e-2

    def get_autoencoder_kl_wan_config(self):
        return {
            "base_dim": 3,
            "z_dim": 16,
            "dim_mult": [1, 1, 1, 1],
            "num_res_blocks": 1,
            "temperal_downsample": [False, True, True],
        }

    @property
    def dummy_input(self):
        batch_size = 2
        num_frames = 9
        num_channels = 3
        sizes = (16, 16)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

        return {"sample": image}

    @property
    def input_shape(self):
        return (3, 9, 16, 16)

    @property
    def output_shape(self):
        return (3, 9, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = self.get_autoencoder_kl_wan_config()
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skip("Gradient checkpointing has not been implemented yet")
    def test_gradient_checkpointing_is_applied(self):
        pass

    @unittest.skip("Test not supported")
    def test_forward_with_norm_groups(self):
        pass

    @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
    def test_layerwise_casting_training(self):
        pass
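
For orientation, here is a minimal sketch (not part of the commit) of what the tester exercises: building AutoencoderKLWan from the tiny config above and pushing the dummy video-shaped input through it. The forward call signature is assumed to follow the usual diffusers autoencoder convention of accepting a "sample" tensor; the shapes mirror dummy_input above.

# Illustrative sketch only; assumes a diffusers build that includes AutoencoderKLWan.
import torch
from diffusers import AutoencoderKLWan

config = {
    "base_dim": 3,
    "z_dim": 16,
    "dim_mult": [1, 1, 1, 1],
    "num_res_blocks": 1,
    "temperal_downsample": [False, True, True],
}
model = AutoencoderKLWan(**config)

# (batch, channels, frames, height, width) — same layout as dummy_input above
sample = torch.randn(2, 3, 9, 16, 16)
with torch.no_grad():
    output = model(sample)  # assumed to reconstruct the input video, as the common tests expect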
File renamed without changes.

tests/pipelines/wan/test_wan.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_accelerator,
    slow,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    PipelineTesterMixin,
)


enable_full_determinism()


class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = WanPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        # TODO: impl FlowDPMSolverMultistepScheduler
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "negative",  # TODO
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
        expected_video = torch.randn(9, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass


@slow
@require_torch_accelerator
class WanPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @unittest.skip("TODO: test needs to be implemented")
    def test_Wanx(self):
        pass
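
As a rough sketch of the real-world path these tests guard (not part of the commit), the pipeline can also be driven from a pretrained checkpoint. The checkpoint id below is a placeholder, not a repository named in this diff; the call arguments mirror the dummy inputs and the integration-test prompt above, and export_to_video is the standard diffusers helper.

# Illustrative only; substitute an actual Wan text-to-video checkpoint from the Hub.
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained("<wan-t2v-checkpoint>", torch_dtype=torch.bfloat16)  # placeholder id
pipe.to("cuda")

video = pipe(
    prompt="A painting of a squirrel eating a burger.",  # prompt reused from the integration test class
    negative_prompt="negative",
    guidance_scale=6.0,
    num_inference_steps=2,  # the fast test uses 2 steps; real generations use far more
    output_type="np",
).frames[0]
export_to_video(video, "wan_sample.mp4", fps=8)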
