Skip to content

Commit 76b697f

Browse files
committed
add model test
1 parent 5797093 commit 76b697f

File tree

2 files changed

+116
-10
lines changed

2 files changed

+116
-10
lines changed

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Optional, Tuple
16+
from typing import Any, Dict, Optional, Tuple
1717

1818
import torch
1919
import torch.nn as nn
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
22-
from ...utils import logging
22+
from ...utils import is_torch_version, logging
2323
from ...utils.torch_utils import maybe_allow_in_graph
2424
from ..attention import FeedForward
2525
from ..attention_processor import Attention, MochiAttnProcessor2_0
@@ -131,7 +131,7 @@ def forward(
131131
) * torch.tanh(enc_gate_msa).unsqueeze(1)
132132
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
133133
context_ff_output = self.ff_context(norm_encoder_hidden_states)
134-
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(0)
134+
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(1)
135135

136136
return hidden_states, encoder_hidden_states
137137

@@ -248,6 +248,12 @@ def __init__(
248248
)
249249
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
250250

251+
self.gradient_checkpointing = False
252+
253+
def _set_gradient_checkpointing(self, module, value=False):
254+
if hasattr(module, "gradient_checkpointing"):
255+
module.gradient_checkpointing = value
256+
251257
def forward(
252258
self,
253259
hidden_states: torch.Tensor,
@@ -280,13 +286,30 @@ def forward(
280286
)
281287

282288
for i, block in enumerate(self.transformer_blocks):
283-
hidden_states, encoder_hidden_states = block(
284-
hidden_states=hidden_states,
285-
encoder_hidden_states=encoder_hidden_states,
286-
temb=temb,
287-
image_rotary_emb=image_rotary_emb,
288-
)
289-
print(hidden_states.mean(), hidden_states.std())
289+
if self.gradient_checkpointing:
290+
291+
def create_custom_forward(module):
292+
def custom_forward(*inputs):
293+
return module(*inputs)
294+
295+
return custom_forward
296+
297+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
298+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
299+
create_custom_forward(block),
300+
hidden_states,
301+
encoder_hidden_states,
302+
temb,
303+
image_rotary_emb,
304+
**ckpt_kwargs,
305+
)
306+
else:
307+
hidden_states, encoder_hidden_states = block(
308+
hidden_states=hidden_states,
309+
encoder_hidden_states=encoder_hidden_states,
310+
temb=temb,
311+
image_rotary_emb=image_rotary_emb,
312+
)
290313

291314
hidden_states = self.norm_out(hidden_states, temb)
292315
hidden_states = self.proj_out(hidden_states)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers import MochiTransformer3DModel
21+
from diffusers.utils.testing_utils import (
22+
enable_full_determinism,
23+
torch_device
24+
)
25+
26+
from ..test_modeling_common import ModelTesterMixin
27+
28+
29+
enable_full_determinism()
30+
31+
32+
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Run the shared ``ModelTesterMixin`` checks against ``MochiTransformer3DModel``.

    The model is instantiated with the small configuration returned by
    ``prepare_init_args_and_inputs_for_common`` and fed the tensors built by
    ``dummy_input``.
    """

    model_class = MochiTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Build a dict of small random model inputs placed on ``torch_device``."""
        batch_size = 2
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        embedding_dim = 16
        sequence_length = 16

        # Video latents laid out as (batch, channels, frames, height, width).
        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
        # Text conditioning laid out as (batch, sequence, embedding).
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        # All-True boolean mask: every text position is marked as valid.
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
        # One integer timestep per batch element, drawn from [0, 1000).
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
            "encoder_attention_mask": encoder_attention_mask,
        }

    @property
    def input_shape(self):
        """Return ``(4, 2, 16, 16)``, the ``hidden_states`` shape in ``dummy_input`` without the batch axis."""
        return (4, 2, 16, 16)

    @property
    def output_shape(self):
        """Return ``(4, 2, 16, 16)``, the same tuple as ``input_shape``."""
        return (4, 2, 16, 16)

    def prepare_init_args_and_inputs_for_common(self):
        """Return ``(init_dict, inputs_dict)`` for the shared tests.

        ``init_dict`` holds the constructor keyword arguments of a tiny model;
        ``inputs_dict`` is ``self.dummy_input``.
        """
        # NOTE(review): ``in_channels`` (4), ``text_embed_dim`` (16) and
        # ``max_sequence_length`` (16) equal the corresponding sizes used in
        # ``dummy_input``; presumably they must stay in sync -- confirm
        # against the model before changing either side.
        init_dict = {
            "patch_size": 2,
            "num_attention_heads": 2,
            "attention_head_dim": 8,
            "num_layers": 2,
            "pooled_projection_dim": 16,
            "in_channels": 4,
            "out_channels": None,
            "qk_norm": "rms_norm",
            "text_embed_dim": 16,
            "time_embed_dim": 4,
            "activation_fn": "swiglu",
            "max_sequence_length": 16,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

0 commit comments

Comments
 (0)