Commit 4f1653c

refactor part 4; modeling tests
1 parent fd18f9a commit 4f1653c

File tree

5 files changed: +81 −46 lines


src/diffusers/models/transformers/transformer_allegro.py

Lines changed: 2 additions & 44 deletions
@@ -239,47 +239,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
         attention_bias (`bool`, *optional*):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
-
-    # {
-    #   "_class_name": "AllegroTransformer3DModel",
-    #   "_diffusers_version": "0.30.3",
-    #   "_name_or_path": "/cpfs/data/user/larrytsai/Projects/Yi-VG/allegro/transformer",
-    #   "activation_fn": "gelu-approximate",
-    #   "attention_bias": true,
-    #   "attention_head_dim": 96,
-    #   "ca_attention_mode": "xformers",
-    #   "caption_channels": 4096,
-    #   "cross_attention_dim": 2304,
-    #   "double_self_attention": false,
-    #   "downsampler": null,
-    #   "dropout": 0.0,
-    #   "in_channels": 4,
-    #   "interpolation_scale_h": 2.0,
-    #   "interpolation_scale_t": 2.2,
-    #   "interpolation_scale_w": 2.0,
-    #   "model_max_length": 300,
-    #   "norm_elementwise_affine": false,
-    #   "norm_eps": 1e-06,
-    #   "norm_type": "ada_norm_single",
-    #   "num_attention_heads": 24,
-    #   "num_embeds_ada_norm": 1000,
-    #   "num_layers": 32,
-    #   "only_cross_attention": false,
-    #   "out_channels": 4,
-    #   "patch_size": 2,
-    #   "patch_size_t": 1,
-    #   "sa_attention_mode": "flash",
-    #   "sample_size": [
-    #     90,
-    #     160
-    #   ],
-    #   "sample_size_t": 22,
-    #   "upcast_attention": false,
-    #   "use_additional_conditions": null,
-    #   "use_linear_projection": false,
-    #   "use_rope": true
-    # }
-
+
     @register_to_config
     def __init__(
         self,
@@ -304,8 +264,6 @@ def __init__(
         interpolation_scale_h: float = 2.0,
         interpolation_scale_w: float = 2.0,
         interpolation_scale_t: float = 2.2,
-        use_rotary_positional_embeddings: bool = True,
-        model_max_length: int = 300,
     ):
         super().__init__()
 
@@ -369,8 +327,8 @@ def _set_gradient_checkpointing(self, module, value=False):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        timestep: Optional[torch.LongTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
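
The last hunk moves `timestep` after `encoder_hidden_states` in the `forward` signature. A minimal sketch of why this only affects positional callers; `old_forward` and `new_forward` are throwaway stand-ins written here for illustration, not the model's actual methods:

import torch


# Stand-in signatures mirroring the old and new parameter order from the hunk above.
def old_forward(hidden_states, timestep=None, encoder_hidden_states=None, **_):
    return hidden_states


def new_forward(hidden_states, encoder_hidden_states=None, timestep=None, **_):
    return hidden_states


hidden_states = torch.randn(2, 4, 8, 8, 8)
encoder_hidden_states = torch.randn(2, 16, 8)
timestep = torch.randint(0, 1000, (2,))

# Keyword-argument callers are unaffected by the reorder:
new_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)

# A positional caller written against the old order, e.g.
#   model(hidden_states, timestep, encoder_hidden_states)
# would now feed `timestep` into the `encoder_hidden_states` slot, so keywords are the safer call style.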

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 0 additions & 2 deletions
@@ -193,7 +193,6 @@ def __init__(
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
-    # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -207,7 +206,6 @@ def encode_prompt(
         negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         clean_caption: bool = False,
         max_sequence_length: int = 300,
-        **kwargs,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
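
Removing `**kwargs` from `encode_prompt` means unsupported or misspelled keyword arguments now raise a `TypeError` instead of being silently swallowed. A toy illustration of that Python behavior, using stand-in functions rather than the pipeline's actual code:

def encode_lenient(prompt, clean_caption=False, **kwargs):
    # Unknown keywords disappear into **kwargs; a typo such as `clean_cation=True` goes unnoticed.
    return prompt


def encode_strict(prompt, clean_caption=False):
    return prompt


encode_lenient("a scenic mountain video", clean_cation=True)  # runs silently despite the typo

try:
    encode_strict("a scenic mountain video", clean_cation=True)
except TypeError as err:
    print(err)  # encode_strict() got an unexpected keyword argument 'clean_cation'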

New file: tests for AllegroTransformer3DModel

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import AllegroTransformer3DModel
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = AllegroTransformer3DModel
+    main_input_name = "hidden_states"
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 2
+        num_channels = 4
+        num_frames = 8
+        height = 8
+        width = 8
+        embedding_dim = 16
+        sequence_length = 16
+
+        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device)
+        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "timestep": timestep,
+        }
+
+    @property
+    def input_shape(self):
+        return (4, 8, 8, 8)
+
+    @property
+    def output_shape(self):
+        return (4, 8, 8, 8)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
+            "num_attention_heads": 2,
+            "attention_head_dim": 8,
+            "in_channels": 4,
+            "out_channels": 4,
+            "num_layers": 1,
+            "cross_attention_dim": 16,
+            "sample_width": 8,
+            "sample_height": 8,
+            "sample_frames": 8,
+            "caption_channels": 8,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
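
For reference, a minimal sketch of what `ModelTesterMixin` effectively exercises with these helpers: build the model from `init_dict` and run a forward pass on the dummy inputs. The config values, tensor shapes, and call pattern are copied from the test above; running on CPU and reading the result via `.sample` are assumptions made here, not something this diff shows.

import torch

from diffusers import AllegroTransformer3DModel

# Config copied from prepare_init_args_and_inputs_for_common above.
init_dict = {
    "num_attention_heads": 2,
    "attention_head_dim": 8,
    "in_channels": 4,
    "out_channels": 4,
    "num_layers": 1,
    "cross_attention_dim": 16,
    "sample_width": 8,
    "sample_height": 8,
    "sample_frames": 8,
    "caption_channels": 8,
}
model = AllegroTransformer3DModel(**init_dict).eval()

# Shapes copied from dummy_input above: (batch, channels, frames, height, width), etc.
inputs = {
    "hidden_states": torch.randn(2, 4, 8, 8, 8),
    "encoder_hidden_states": torch.randn(2, 16, 8),
    "timestep": torch.randint(0, 1000, (2,)),
}

with torch.no_grad():
    output = model(**inputs)

# The mixin compares the result against `output_shape` of (4, 8, 8, 8) per sample;
# accessing `.sample` on the returned output object is an assumption.
print(output.sample.shape)  # expected: torch.Size([2, 4, 8, 8, 8])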

tests/pipelines/allegro/__init__.py

Whitespace-only changes.

tests/pipelines/allegro/test_allegro.py

Whitespace-only changes.
