Commit dfc195c

add transformer test
1 parent 84f4abe commit dfc195c

File tree

1 file changed: 81 additions, 0 deletions

1 file changed

+81
-0
lines changed
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
    model_class = WanTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 1
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12

        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

    @property
    def input_shape(self):
        return (4, 1, 16, 16)

    @property
    def output_shape(self):
        return (4, 1, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "in_channels": 4,
            "out_channels": 4,
            "text_dim": 16,
            "freq_dim": 256,
            "ffn_dim": 32,
            "num_layers": 2,
            "cross_attn_norm": True,
            "qk_norm": "rms_norm_across_heads",
            "rope_max_seq_len": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"WanTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
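
For context, below is a minimal sketch of what ModelTesterMixin ends up exercising with this configuration: building WanTransformer3DModel from the test's init_dict and running the dummy inputs through a forward pass. The constructor arguments and input keyword names are taken directly from the test above; the exact forward signature and the `.sample` attribute on the output are assumptions carried over from other diffusers transformer models, not something this diff confirms.

# Sketch only, not part of the commit. Assumes WanTransformer3DModel's forward
# accepts the same keyword names as the dummy_input dict and returns an object
# with a `.sample` tensor (as other diffusers transformers do).
import torch

from diffusers import WanTransformer3DModel

model = WanTransformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=12,
    in_channels=4,
    out_channels=4,
    text_dim=16,
    freq_dim=256,
    ffn_dim=32,
    num_layers=2,
    cross_attn_norm=True,
    qk_norm="rms_norm_across_heads",
    rope_max_seq_len=32,
)

# Dummy tensors matching the shapes used in dummy_input above.
hidden_states = torch.randn(1, 4, 2, 16, 16)     # (batch, channels, frames, height, width)
timestep = torch.randint(0, 1000, (1,))
encoder_hidden_states = torch.randn(1, 12, 16)   # (batch, sequence_length, text_dim)

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
    ).sample

print(out.shape)  # expected to keep the input layout: (1, 4, 2, 16, 16)

Running pytest against the new test module (its path is not shown in this extract) would execute this forward pass along with the rest of the checks inherited from ModelTesterMixin, plus the gradient-checkpointing test overridden here.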
