
Commit da834d5

add tests
1 parent 68f817a commit da834d5

4 files changed: +115 −16 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 5 deletions
@@ -767,7 +767,7 @@ def __init__(
             channels,
             kernel_size,
             padding=kernel_size // 2,
-            groups=3 * in_channels,
+            groups=channels,
             bias=False,
         )
         self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)

@@ -786,7 +786,7 @@ def __init__(
         in_channels: int,
         out_channels: int,
         num_attention_heads: Optional[int] = None,
-        heads_ratio: float = 1.0,
+        mult: float = 1.0,
         attention_head_dim: int = 8,
         norm_type: str = "batch_norm",
         kernel_sizes: Tuple[int, ...] = (5,),

@@ -804,9 +804,7 @@ def __init__(
         self.residual_connection = residual_connection

         num_attention_heads = (
-            int(in_channels // attention_head_dim * heads_ratio)
-            if num_attention_heads is None
-            else num_attention_heads
+            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
        )
         inner_dim = num_attention_heads * attention_head_dim

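The first hunk changes the depthwise convolution's `groups` from `3 * in_channels` to `channels`; the later hunks rename `heads_ratio` to `mult` while keeping the head-count arithmetic. A minimal standalone sketch of that arithmetic (the concrete values below are illustrative, not taken from the diff):

# How `mult` scales the derived head count when `num_attention_heads` is None
# (same expression as the hunk above, lifted out of the module for illustration).
in_channels = 512
attention_head_dim = 8
mult = 1.0  # renamed from `heads_ratio`
num_attention_heads = None

num_attention_heads = (
    int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim

print(num_attention_heads, inner_dim)  # 64 512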
src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 23 additions & 10 deletions
@@ -24,6 +24,7 @@
 from ..attention_processor import SanaMultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm, get_normalization
+from .vae import DecoderOutput


 class GLUMBConv(nn.Module):

@@ -90,8 +91,8 @@ class EfficientViTBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        heads_ratio: float = 1.0,
-        dim: int = 32,
+        mult: float = 1.0,
+        attention_head_dim: int = 32,
         qkv_multiscales: Tuple[int, ...] = (5,),
         norm_type: str = "batch_norm",
     ) -> None:

@@ -100,8 +101,8 @@ def __init__(
         self.attn = SanaMultiscaleLinearAttention(
             in_channels=in_channels,
             out_channels=in_channels,
-            heads_ratio=heads_ratio,
-            attention_head_dim=dim,
+            mult=mult,
+            attention_head_dim=attention_head_dim,
             norm_type=norm_type,
             kernel_sizes=qkv_multiscales,
             residual_connection=True,

@@ -122,6 +123,7 @@ def get_block(
     block_type: str,
     in_channels: int,
     out_channels: int,
+    attention_head_dim: int,
     norm_type: str,
     act_fn: str,
     qkv_mutliscales: Tuple[int] = (),

@@ -130,7 +132,9 @@ def get_block(
         block = ResBlock(in_channels, out_channels, norm_type, act_fn)

     elif block_type == "EfficientViTBlock":
-        block = EfficientViTBlock(in_channels, norm_type=norm_type, qkv_multiscales=qkv_mutliscales)
+        block = EfficientViTBlock(
+            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
+        )

     else:
         raise ValueError(f"Block with {block_type=} is not supported.")

@@ -224,6 +228,7 @@ def __init__(
         self,
         in_channels: int,
         latent_channels: int,
+        attention_head_dim: int = 32,
         block_type: Union[str, Tuple[str]] = "ResBlock",
         block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
         layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),

@@ -262,6 +267,7 @@ def __init__(
                 block_type[i],
                 out_channel,
                 out_channel,
+                attention_head_dim=attention_head_dim,
                 norm_type="rms_norm",
                 act_fn="silu",
                 qkv_mutliscales=qkv_multiscales[i],

@@ -305,6 +311,7 @@ def __init__(
         self,
         in_channels: int,
         latent_channels: int,
+        attention_head_dim: int = 32,
         block_type: Union[str, Tuple[str]] = "ResBlock",
         block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
         layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),

@@ -348,6 +355,7 @@ def __init__(
                 block_type[i],
                 out_channel,
                 out_channel,
+                attention_head_dim=attention_head_dim,
                 norm_type=norm_type[i],
                 act_fn=act_fn[i],
                 qkv_mutliscales=qkv_multiscales[i],

@@ -425,13 +433,14 @@ class AutoencoderDC(ModelMixin, ConfigMixin):
         A scaling factor applied during model operations.
     """

-    _supports_gradient_checkpointing = True
+    _supports_gradient_checkpointing = False

     @register_to_config
     def __init__(
         self,
         in_channels: int = 3,
         latent_channels: int = 32,
+        attention_head_dim: int = 32,
         encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
         decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),

@@ -451,6 +460,7 @@ def __init__(
         self.encoder = Encoder(
             in_channels=in_channels,
             latent_channels=latent_channels,
+            attention_head_dim=attention_head_dim,
             block_type=encoder_block_types,
             block_out_channels=encoder_block_out_channels,
             layers_per_block=encoder_layers_per_block,

@@ -460,6 +470,7 @@ def __init__(
         self.decoder = Decoder(
             in_channels=in_channels,
             latent_channels=latent_channels,
+            attention_head_dim=attention_head_dim,
             block_type=decoder_block_types,
             block_out_channels=decoder_block_out_channels,
             layers_per_block=decoder_layers_per_block,

@@ -480,7 +491,9 @@ def decode(self, x: torch.Tensor) -> torch.Tensor:
         x = self.decoder(x)
         return x

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.encoder(x)
-        x = self.decoder(x)
-        return x
+    def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+        z = self.encode(sample)
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
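To illustrate the new `forward` contract (encode, decode, `DecoderOutput` wrapper), here is a usage sketch that reuses the tiny configuration from the test file added in this commit; it is not an official example, and the expected output shape follows the test's `output_shape`:

import torch

from diffusers import AutoencoderDC

# Tiny config mirroring get_autoencoder_dc_config() in the new test file.
model = AutoencoderDC(
    in_channels=3,
    latent_channels=4,
    attention_head_dim=2,
    encoder_block_types=("ResBlock", "EfficientViTBlock"),
    decoder_block_types=("ResBlock", "EfficientViTBlock"),
    encoder_block_out_channels=(8, 8),
    decoder_block_out_channels=(8, 8),
    encoder_qkv_multiscales=((), (5,)),
    decoder_qkv_multiscales=((), (5,)),
    encoder_layers_per_block=(1, 1),
    decoder_layers_per_block=(1, 1),
    downsample_block_type="conv",
    upsample_block_type="interpolate",
    decoder_norm_types="rms_norm",
    decoder_act_fns="silu",
    scaling_factor=0.41407,
)

sample = torch.randn(4, 3, 32, 32)

# Default return_dict=True wraps the reconstruction in DecoderOutput ...
out = model(sample)
print(out.sample.shape)  # expected per the test: torch.Size([4, 3, 32, 32])

# ... while return_dict=False returns a plain tuple.
(dec,) = model(sample, return_dict=False)
print(dec.shape)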

src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

Lines changed: 2 additions & 1 deletion
@@ -34,6 +34,7 @@
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel

+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -732,7 +733,7 @@ def __call__(
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
-
+
             if XLA_AVAILABLE:
                 xm.mark_step()

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderDC
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+    model_class = AutoencoderDC
+    main_input_name = "sample"
+    base_precision = 1e-2
+
+    def get_autoencoder_dc_config(self):
+        return {
+            "in_channels": 3,
+            "latent_channels": 4,
+            "attention_head_dim": 2,
+            "encoder_block_types": (
+                "ResBlock",
+                "EfficientViTBlock",
+            ),
+            "decoder_block_types": (
+                "ResBlock",
+                "EfficientViTBlock",
+            ),
+            "encoder_block_out_channels": (8, 8),
+            "decoder_block_out_channels": (8, 8),
+            "encoder_qkv_multiscales": ((), (5,)),
+            "decoder_qkv_multiscales": ((), (5,)),
+            "encoder_layers_per_block": (1, 1),
+            "decoder_layers_per_block": [1, 1],
+            "downsample_block_type": "conv",
+            "upsample_block_type": "interpolate",
+            "decoder_norm_types": "rms_norm",
+            "decoder_act_fns": "silu",
+            "scaling_factor": 0.41407,
+        }
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+
+        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+
+        return {"sample": image}
+
+    @property
+    def input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = self.get_autoencoder_dc_config()
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
+    def test_forward_with_norm_groups(self):
+        pass

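This view omits the new test module's path; by diffusers convention it would sit under tests/models/autoencoders/. It exercises AutoencoderDC through the shared ModelTesterMixin and UNetTesterMixin checks and runs with pytest like any other model test suite, so the tiny config above also serves as the smallest known-good configuration for the model.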