Skip to content

Commit c3eebb2

Browse files
committed
add model tests
1 parent d5b3db9 commit c3eebb2

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,10 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
935935
continue
936936
if name in skip:
937937
continue
938+
# TODO(aryan): remove the below lines after looking into easyanimate transformer a little more
939+
# It currently errors out the gradient checkpointing test because the gradients for attn2.to_out is None
940+
if param.grad is None:
941+
continue
938942
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
939943

940944
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import EasyAnimateTransformer3DModel
21+
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
22+
23+
from ..test_modeling_common import ModelTesterMixin
24+
25+
26+
enable_full_determinism()
27+
28+
29+
class EasyAnimateTransformerTests(ModelTesterMixin, unittest.TestCase):
30+
model_class = EasyAnimateTransformer3DModel
31+
main_input_name = "hidden_states"
32+
uses_custom_attn_processor = True
33+
34+
@property
35+
def dummy_input(self):
36+
batch_size = 2
37+
num_channels = 4
38+
num_frames = 2
39+
height = 16
40+
width = 16
41+
embedding_dim = 16
42+
sequence_length = 16
43+
44+
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
45+
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
46+
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
47+
48+
return {
49+
"hidden_states": hidden_states,
50+
"timestep": timestep,
51+
"timestep_cond": None,
52+
"encoder_hidden_states": encoder_hidden_states,
53+
"encoder_hidden_states_t5": None,
54+
"image_rotary_emb": None, # TODO(aryan): Create EasyAnimateRotaryPosEmbed layer
55+
"inpaint_latents": None,
56+
"control_latents": None,
57+
}
58+
59+
@property
60+
def input_shape(self):
61+
return (4, 2, 16, 16)
62+
63+
@property
64+
def output_shape(self):
65+
return (4, 2, 16, 16)
66+
67+
def prepare_init_args_and_inputs_for_common(self):
68+
init_dict = {
69+
"attention_head_dim": 8,
70+
"in_channels": 4,
71+
"mmdit_layers": 2,
72+
"num_attention_heads": 2,
73+
"num_layers": 2,
74+
"out_channels": 4,
75+
"patch_size": 2,
76+
"sample_height": 60,
77+
"sample_width": 90,
78+
"text_embed_dim": 16,
79+
"time_embed_dim": 8,
80+
"time_position_encoding_type": "3d_rope",
81+
"timestep_activation_fn": "silu",
82+
}
83+
inputs_dict = self.dummy_input
84+
return init_dict, inputs_dict
85+
86+
def test_gradient_checkpointing_is_applied(self):
87+
expected_set = {"EasyAnimateTransformer3DModel"}
88+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

0 commit comments

Comments
 (0)