Commit 7b9d7e5

make style

1 parent 1f8a3b3 commit 7b9d7e5
File tree

5 files changed
+60 -37 lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 
 import torch
 from safetensors.torch import load_file
-from transformers import T5EncoderModel, T5Tokenizer
 
 from diffusers import AutoencoderDC
 

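The deleted transformers import was simply unused: the conversion script only needs torch, safetensors, and AutoencoderDC. A minimal sketch of the load path those remaining imports serve (the checkpoint name and config kwargs below are placeholders, not the script's actual CLI arguments):

from safetensors.torch import load_file

from diffusers import AutoencoderDC

# Placeholder checkpoint name; the real script parses CLI args and remaps
# the original state-dict keys before loading.
original_state_dict = load_file("dc-ae-checkpoint.safetensors")

# Hypothetical config kwargs; the script derives these from the dc_ae_* helpers.
ae = AutoencoderDC(latent_channels=32)
ae.load_state_dict(original_state_dict, strict=False)
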
src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -572,8 +572,8 @@
         AllegroTransformer3DModel,
         AsymmetricAutoencoderKL,
         AuraFlowTransformer2DModel,
-        AutoencoderKL,
         AutoencoderDC,
+        AutoencoderKL,
         AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
         AutoencoderKLMochi,

src/diffusers/models/attention.py

Lines changed: 8 additions & 2 deletions
@@ -19,10 +19,16 @@
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, get_activation
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, RMSNormNd
+from .normalization import (
+    AdaLayerNorm,
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    RMSNorm,
+    SD35AdaLayerNormZeroX,
+)
 
 
 logger = logging.get_logger(__name__)

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 50 additions & 32 deletions
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, Tuple, Optional
+from typing import Any, Callable, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -34,7 +34,7 @@ def val2tuple(x: list | tuple | Any, min_len: int = 1) -> tuple:
     return tuple(x)
 
 
-def build_norm(name: Optional[str]="bn2d", num_features: Optional[int]=None) -> Optional[nn.Module]:
+def build_norm(name: Optional[str] = "bn2d", num_features: Optional[int] = None) -> Optional[nn.Module]:
     if name is None:
         norm = None
     elif name == "rms2d":
@@ -481,7 +481,7 @@ def build_stage_main(
 
         in_channels = width if d > 0 else input_width
         out_channels = width
-
+
         if current_block_type == "ResBlock":
             assert in_channels == out_channels
             block = ResBlock(
@@ -501,7 +501,7 @@ def build_stage_main(
             block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
         else:
             raise ValueError(f"block_type {current_block_type} is not supported")
-
+
         stage.append(block)
     return stage
 
@@ -543,7 +543,7 @@ def __init__(
         shortcut: bool = True,
     ) -> None:
         super().__init__()
-
+
         self.downsample = downsample
         self.factor = 2
         self.stride = 1 if downsample else 2
@@ -552,21 +552,21 @@ def __init__(
         if downsample:
             assert out_channels % out_ratio == 0
             out_channels = out_channels // out_ratio
-
+
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=self.stride,
             padding=kernel_size // 2,
         )
-
+
         self.shortcut = None
         if shortcut:
             self.shortcut = DownsamplePixelUnshuffleChannelAveraging(
                 in_channels=in_channels, out_channels=out_channels, factor=2
             )
-
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         x = self.conv(hidden_states)
         if self.downsample:
@@ -594,8 +594,8 @@ def __init__(
         self.interpolation_mode = interpolation_mode
         self.factor = 2
         self.stride = 1
-
-        out_ratio = self.factor ** 2
+
+        out_ratio = self.factor**2
         if not interpolate:
             out_channels = out_channels * out_ratio
 
@@ -612,20 +612,20 @@ def __init__(
             self.shortcut = UpsampleChannelDuplicatingPixelUnshuffle(
                 in_channels=in_channels, out_channels=out_channels, factor=2
             )
-
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.interpolate:
             x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
             x = self.conv(x)
         else:
             x = self.conv(hidden_states)
             x = F.pixel_shuffle(x, self.factor)
-
+
         if self.shortcut is not None:
             hidden_states = x + self.shortcut(hidden_states)
         else:
             hidden_states = x
-
+
         return hidden_states
 
 
@@ -644,9 +644,7 @@ def __init__(
         self.num_stages = num_stages
         assert len(layers_per_block) == num_stages
         assert len(block_out_channels) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
+        assert isinstance(block_type, str) or (isinstance(block_type, list) and len(block_type) == num_stages)
 
         factor = 1 if layers_per_block[0] > 0 else 2
 
@@ -722,19 +720,11 @@ def __init__(
         self.num_stages = num_stages
         assert len(layers_per_block) == num_stages
         assert len(block_out_channels) == num_stages
-        assert isinstance(block_type, str) or (
-            isinstance(block_type, list) and len(block_type) == num_stages
-        )
+        assert isinstance(block_type, str) or (isinstance(block_type, list) and len(block_type) == num_stages)
         assert isinstance(norm, str) or (isinstance(norm, list) and len(norm) == num_stages)
         assert isinstance(act, str) or (isinstance(act, list) and len(act) == num_stages)
 
-        self.conv_in = nn.Conv2d(
-            latent_channels,
-            block_out_channels[-1],
-            kernel_size=3,
-            stride=1,
-            padding=1
-        )
+        self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
         self.norm_in = UpsampleChannelDuplicatingPixelUnshuffle(
             in_channels=latent_channels, out_channels=block_out_channels[-1], factor=1
         )
@@ -767,9 +757,15 @@ def __init__(
             stages.insert(0, nn.Sequential(*current_stage))
         self.stages = nn.ModuleList(stages)
 
-        factor = 1 if layers_per_block[0] > 0 else 2
+        factor = 1 if layers_per_block[0] > 0 else 2
 
-        self.norm_out = RMSNormNd(block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], eps=1e-5, elementwise_affine=True, bias=True, channel_dim=1)
+        self.norm_out = RMSNormNd(
+            block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
+            eps=1e-5,
+            elementwise_affine=True,
+            bias=True,
+            channel_dim=1,
+        )
         self.conv_act = nn.ReLU()
         self.conv_out = None
 
@@ -884,7 +880,9 @@ def dc_ae_f32c32(name: str) -> dict:
     return cfg
 
 
-def dc_ae_f64c128(name: str,) -> dict:
+def dc_ae_f64c128(
+    name: str,
+) -> dict:
     if name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
         cfg = {
             "latent_channels": 128,
@@ -901,14 +899,34 @@ def dc_ae_f64c128(name: str,) -> dict:
     return cfg
 
 
-def dc_ae_f128c512(name: str,) -> dict:
+def dc_ae_f128c512(
+    name: str,
+) -> dict:
     if name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
         cfg = {
             "latent_channels": 512,
-            "encoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
+            "encoder_block_type": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+            ],
             "block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
             "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
-            "decoder_block_type": ["ResBlock", "ResBlock", "ResBlock", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU", "EViT_GLU"],
+            "decoder_block_type": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+                "EViT_GLU",
+            ],
             "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
             "decoder_norm": ["bn2d", "bn2d", "bn2d", "rms2d", "rms2d", "rms2d", "rms2d", "rms2d"],
             "decoder_act": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],

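For orientation on the reformatted helpers above: build_norm maps a norm name to a module (or None when name is None), and the dc_ae_* functions return plain config dicts keyed by checkpoint name. A rough usage sketch, assuming these module-level helpers are importable straight from the file (values illustrative only):

from diffusers.models.autoencoders.autoencoder_dc import build_norm, dc_ae_f128c512

# "rms2d" selects the RMSNormNd-backed 2D norm; name=None returns no module.
norm = build_norm("rms2d", num_features=512)

cfg = dc_ae_f128c512("dc-ae-f128c512-in-1.0")
assert cfg["latent_channels"] == 512
# Per-stage lists carry one entry per stage (8 stages in this variant).
assert len(cfg["block_out_channels"]) == len(cfg["encoder_block_type"]) == 8
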
src/diffusers/models/normalization.py

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
 
         self.weight = None
         self.bias = None
-
+
         if elementwise_affine:
             self.weight = nn.Parameter(torch.ones(dim))
             if bias:

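This RMSNormNd is the same class the decoder's norm_out uses above: with elementwise_affine=True it owns a weight parameter, plus a bias when bias=True. A small sketch of that construction (channel_dim=1 matches the decoder usage; the shape-preserving forward is an assumption, not shown in this diff):

import torch

from diffusers.models.normalization import RMSNormNd

norm = RMSNormNd(512, eps=1e-5, elementwise_affine=True, bias=True, channel_dim=1)
x = torch.randn(1, 512, 8, 8)  # NCHW feature map, normalized over the channel dim

# Assumed shape-preserving, like the other norms in this module.
assert norm(x).shape == x.shape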