
Commit 7ce9ff2

committed
remove build encoder/decoder project in/out
1 parent 96e844b commit 7ce9ff2
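
This commit deletes the four module-level helpers (build_encoder_project_in_block, build_encoder_project_out_block, build_decoder_project_in_block, build_decoder_project_out_block) and inlines their bodies into Encoder.__init__ and Decoder.__init__. As a quick orientation, here is a self-contained sketch of the branch that replaces build_encoder_project_in_block; the Encoder class and build_downsample_block are omitted, and a plain strided conv stands in for the downsample block, so treat this as a sketch rather than the file's actual code:

import torch
from torch import nn

def make_project_in(in_channels: int, width_list: list, depth_list: list) -> nn.Module:
    # Mirrors the inlined branch in Encoder.__init__ below: a plain 3x3 conv when the first
    # stage has depth, otherwise a stride-2 projection into the second stage's width.
    if depth_list[0] > 0:
        return nn.Conv2d(in_channels=in_channels, out_channels=width_list[0], kernel_size=3, padding=1)
    elif depth_list[1] > 0:
        # dc_ae.py calls build_downsample_block(...) here; a strided conv stands in to keep this runnable.
        return nn.Conv2d(in_channels=in_channels, out_channels=width_list[1], kernel_size=3, stride=2, padding=1)
    raise ValueError(f"depth list {depth_list} is not supported for encoder project in")

x = torch.randn(1, 3, 64, 64)
print(make_project_in(3, [128, 256], [0, 2])(x).shape)  # torch.Size([1, 256, 32, 32])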

File tree

1 file changed: +104 additions, -144 deletions
  • src/diffusers/models/autoencoders


src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 104 additions & 144 deletions
@@ -41,8 +41,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-
-
 class ConvLayer(nn.Module):
     def __init__(
         self,
@@ -519,125 +517,6 @@ def build_upsample_block(block_type: str, in_channels: int, out_channels: int, s
     return block


-def build_encoder_project_in_block(in_channels: int, out_channels: int, factor: int, downsample_block_type: str):
-    if factor == 1:
-        block = nn.Conv2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            padding=1,
-        )
-    elif factor == 2:
-        block = build_downsample_block(
-            block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
-        )
-    else:
-        raise ValueError(f"downsample factor {factor} is not supported for encoder project in")
-    return block
-
-
-def build_encoder_project_out_block(
-    in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
-):
-    layers: list[nn.Module] = []
-
-    if norm is None:
-        pass
-    elif norm == "rms2d":
-        layers.append(RMSNorm2d(normalized_shape=in_channels))
-    elif norm == "bn2d":
-        layers.append(BatchNorm2d(num_features=in_channels))
-    else:
-        raise ValueError(f"norm {norm} is not supported")
-
-    if act is not None:
-        layers.append(get_activation(act))
-    layers.append(ConvLayer(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=3,
-        stride=1,
-        use_bias=True,
-        norm=None,
-        act_func=None,
-    ))
-    block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
-
-    if shortcut is None:
-        pass
-    elif shortcut == "averaging":
-        shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
-            in_channels=in_channels, out_channels=out_channels, factor=1
-        )
-        block = ResidualBlock(block, shortcut_block)
-    else:
-        raise ValueError(f"shortcut {shortcut} is not supported for encoder project out")
-    return block
-
-
-def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]):
-    block = ConvLayer(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=3,
-        stride=1,
-        use_bias=True,
-        norm=None,
-        act_func=None,
-    )
-    if shortcut is None:
-        pass
-    elif shortcut == "duplicating":
-        shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
-            in_channels=in_channels, out_channels=out_channels, factor=1
-        )
-        block = ResidualBlock(block, shortcut_block)
-    else:
-        raise ValueError(f"shortcut {shortcut} is not supported for decoder project in")
-    return block
-
-
-def build_decoder_project_out_block(
-    in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
-):
-    layers: list[nn.Module] = []
-
-    if norm is None:
-        pass
-    elif norm == "rms2d":
-        layers.append(RMSNorm2d(normalized_shape=in_channels))
-    elif norm == "bn2d":
-        layers.append(BatchNorm2d(num_features=in_channels))
-    else:
-        raise ValueError(f"norm {norm} is not supported")
-
-    if act is not None:
-        layers.append(get_activation(act))
-
-    if factor == 1:
-        layers.append(
-            ConvLayer(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                stride=1,
-                use_bias=True,
-                norm=None,
-                act_func=None,
-            )
-        )
-    elif factor == 2:
-        layers.append(
-            build_upsample_block(
-                block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
-            )
-        )
-    else:
-        raise ValueError(f"upsample factor {factor} is not supported for decoder project out")
-    block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
-    return block
-
-
 class Encoder(nn.Module):
     def __init__(
         self,
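
Both the deleted helpers above and the inlined replacements below share the same wrapping pattern: the layers are stacked under a nested "op_list" Sequential (presumably so parameter names stay aligned with existing checkpoints), which can then be wrapped in a residual shortcut. A minimal sketch, with standard torch modules standing in for the file's ConvLayer / RMSNorm2d / PixelUnshuffleChannelAveragingDownsample2D:

import torch
from collections import OrderedDict
from torch import nn

# Stack the block's layers under an "op_list" Sequential, as the helpers (and now the
# inlined code) do; the outer Sequential keys its single child as "op_list".
layers = [nn.BatchNorm2d(8), nn.SiLU(), nn.Conv2d(8, 4, kernel_size=3, padding=1)]
block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))

print(block(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 4, 16, 16])
print(list(block.state_dict())[:2])            # ['op_list.0.weight', 'op_list.0.bias']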
@@ -665,14 +544,23 @@ def __init__(
             raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
         if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
             raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")
+
+        # project in
+        if depth_list[0] > 0:
+            self.project_in = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=width_list[0],
+                kernel_size=3,
+                padding=1,
+            )
+        elif depth_list[1] > 0:
+            self.project_in = build_downsample_block(
+                block_type=downsample_block_type, in_channels=in_channels, out_channels=width_list[1], shortcut=None
+            )
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for encoder project in")

-        self.project_in = build_encoder_project_in_block(
-            in_channels=in_channels,
-            out_channels=width_list[0] if depth_list[0] > 0 else width_list[1],
-            factor=1 if depth_list[0] > 0 else 2,
-            downsample_block_type=downsample_block_type,
-        )
-
+        # stages
         self.stages: list[nn.Module] = []
         for stage_id, (width, depth) in enumerate(zip(width_list, depth_list)):
             stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
@@ -690,13 +578,39 @@ def __init__(
             self.stages.append(nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
         self.stages = nn.ModuleList(self.stages)

-        self.project_out = build_encoder_project_out_block(
+        # project out
+        project_out_layers: list[nn.Module] = []
+        if out_norm is None:
+            pass
+        elif out_norm == "rms2d":
+            project_out_layers.append(RMSNorm2d(normalized_shape=width_list[-1]))
+        elif out_norm == "bn2d":
+            project_out_layers.append(BatchNorm2d(num_features=width_list[-1]))
+        else:
+            raise ValueError(f"norm {out_norm} is not supported for encoder project out")
+        if out_act is not None:
+            project_out_layers.append(get_activation(out_act))
+        project_out_out_channels = 2 * latent_channels if double_latent else latent_channels
+        project_out_layers.append(ConvLayer(
             in_channels=width_list[-1],
-            out_channels=2 * latent_channels if double_latent else latent_channels,
-            norm=out_norm,
-            act=out_act,
-            shortcut=out_shortcut,
-        )
+            out_channels=project_out_out_channels,
+            kernel_size=3,
+            stride=1,
+            use_bias=True,
+            norm=None,
+            act_func=None,
+        ))
+        project_out_block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*project_out_layers))]))
+        if out_shortcut is None:
+            pass
+        elif out_shortcut == "averaging":
+            shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
+                in_channels=width_list[-1], out_channels=project_out_out_channels, factor=1
+            )
+            project_out_block = ResidualBlock(project_out_block, shortcut_block)
+        else:
+            raise ValueError(f"shortcut {out_shortcut} is not supported for encoder project out")
+        self.project_out = project_out_block

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.project_in(x)
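
The forward path itself is untouched by this commit; only `x = self.project_in(x)` is visible in the context above. For orientation, here is a hedged sketch of the data flow implied by the attributes built in __init__ (project_in -> stages -> project_out), written as a free function with toy stand-ins so it runs on its own. The stage loop is an assumption based on the module structure, not something shown in this diff:

import torch
from torch import nn

def encoder_forward(project_in: nn.Module, stages: nn.ModuleList, project_out: nn.Module,
                    x: torch.Tensor) -> torch.Tensor:
    # Assumed flow: input projection, then each stage in order, then the output projection.
    x = project_in(x)
    for stage in stages:
        x = stage(x)
    return project_out(x)

# Toy usage with identity stand-ins for the real stage and projection blocks.
y = encoder_forward(nn.Conv2d(3, 8, 3, padding=1), nn.ModuleList([nn.Identity()]), nn.Identity(),
                    torch.randn(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 8, 32, 32])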
@@ -739,12 +653,28 @@ def __init__(
         if not isinstance(act, (str, list)) or (isinstance(act, list) and len(act) != num_stages):
             raise ValueError(f"act should be either a str or a list of str with length {num_stages}, but got {act}")

-        self.project_in = build_decoder_project_in_block(
+        # project in
+        project_in_block = ConvLayer(
             in_channels=latent_channels,
             out_channels=width_list[-1],
-            shortcut=in_shortcut,
+            kernel_size=3,
+            stride=1,
+            use_bias=True,
+            norm=None,
+            act_func=None,
         )
+        if in_shortcut is None:
+            pass
+        elif in_shortcut == "duplicating":
+            shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
+                in_channels=latent_channels, out_channels=width_list[-1], factor=1
+            )
+            project_in_block = ResidualBlock(project_in_block, shortcut_block)
+        else:
+            raise ValueError(f"shortcut {in_shortcut} is not supported for decoder project in")
+        self.project_in = project_in_block

+        # stages
         self.stages: list[nn.Module] = []
         for stage_id, (width, depth) in reversed(list(enumerate(zip(width_list, depth_list)))):
             stage = []
@@ -775,14 +705,44 @@ def __init__(
             self.stages.insert(0, nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
         self.stages = nn.ModuleList(self.stages)

-        self.project_out = build_decoder_project_out_block(
-            in_channels=width_list[0] if depth_list[0] > 0 else width_list[1],
-            out_channels=in_channels,
-            factor=1 if depth_list[0] > 0 else 2,
-            upsample_block_type=upsample_block_type,
-            norm=out_norm,
-            act=out_act,
-        )
+        # project out
+        project_out_layers: list[nn.Module] = []
+        if depth_list[0] > 0:
+            project_out_in_channels = width_list[0]
+        elif depth_list[1] > 0:
+            project_out_in_channels = width_list[1]
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for decoder project out")
+        if out_norm is None:
+            pass
+        elif out_norm == "rms2d":
+            project_out_layers.append(RMSNorm2d(normalized_shape=project_out_in_channels))
+        elif out_norm == "bn2d":
+            project_out_layers.append(BatchNorm2d(num_features=project_out_in_channels))
+        else:
+            raise ValueError(f"norm {out_norm} is not supported for decoder project out")
+        project_out_layers.append(get_activation(out_act))
+        if depth_list[0] > 0:
+            project_out_layers.append(
+                ConvLayer(
+                    in_channels=project_out_in_channels,
+                    out_channels=in_channels,
+                    kernel_size=3,
+                    stride=1,
+                    use_bias=True,
+                    norm=None,
+                    act_func=None,
+                )
+            )
+        elif depth_list[1] > 0:
+            project_out_layers.append(
+                build_upsample_block(
+                    block_type=upsample_block_type, in_channels=project_out_in_channels, out_channels=in_channels, shortcut=None
+                )
+            )
+        else:
+            raise ValueError(f"depth list {depth_list} is not supported for decoder project out")
+        self.project_out = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*project_out_layers))]))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.project_in(x)
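
Because the inlined code assigns the resulting modules to the same attributes (self.project_in, self.project_out) and keeps the same internal structure, the state_dict layout should be unchanged by this refactor. A toy, hypothetical demonstration of that general point (none of these names come from dc_ae.py):

from torch import nn

def build_project_in(in_channels: int, out_channels: int) -> nn.Module:
    # Helper-style construction, analogous to the deleted build_*_block functions.
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

class WithHelper(nn.Module):
    def __init__(self):
        super().__init__()
        self.project_in = build_project_in(3, 8)

class Inlined(nn.Module):
    def __init__(self):
        super().__init__()
        # Same module, built directly in __init__ instead of via the helper.
        self.project_in = nn.Conv2d(3, 8, kernel_size=3, padding=1)

# Identical attribute names and module structure -> identical parameter keys,
# so checkpoints saved before the refactor still load after it.
assert list(WithHelper().state_dict()) == list(Inlined().state_dict())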
