Skip to content

Commit 883bcf4

Browse files
committed
remove opsequential
1 parent 80dce02 commit 883bcf4

File tree

1 file changed

+28
-39
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+28
-39
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 28 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -431,21 +431,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
431431
return res
432432

433433

434-
class OpSequential(nn.Module):
    """Sequential container that silently ignores ``None`` entries.

    Lets callers assemble a pipeline with optional stages (e.g. an optional
    norm or activation) without filtering the ``None`` placeholders
    themselves. Only the non-``None`` modules are registered and applied.
    """

    def __init__(self, op_list: list[Optional[nn.Module]]):
        super().__init__()
        # Drop None placeholders; register the rest so they are tracked
        # as submodules (parameters, .to(), state_dict, etc.).
        self.op_list = nn.ModuleList([module for module in op_list if module is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each retained module to ``x`` in registration order."""
        for module in self.op_list:
            x = module(x)
        return x
448-
449434
def build_block(
450435
block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str]
451436
) -> nn.Module:
@@ -557,21 +542,22 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:
557542
def build_encoder_project_out_block(
558543
in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
559544
):
560-
block = OpSequential(
561-
[
562-
build_norm(norm),
563-
get_activation(act) if act is not None else None,
564-
ConvLayer(
565-
in_channels=in_channels,
566-
out_channels=out_channels,
567-
kernel_size=3,
568-
stride=1,
569-
use_bias=True,
570-
norm=None,
571-
act_func=None,
572-
),
573-
]
574-
)
545+
layers = []
546+
if norm is not None:
547+
layers.append(build_norm(norm))
548+
if act is not None:
549+
layers.append(get_activation(act))
550+
layers.append(ConvLayer(
551+
in_channels=in_channels,
552+
out_channels=out_channels,
553+
kernel_size=3,
554+
stride=1,
555+
use_bias=True,
556+
norm=None,
557+
act_func=None,
558+
))
559+
block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
560+
575561
if shortcut is None:
576562
pass
577563
elif shortcut == "averaging":
@@ -609,10 +595,12 @@ def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut
609595
def build_decoder_project_out_block(
610596
in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
611597
):
612-
layers: list[nn.Module] = [
613-
build_norm(norm, in_channels),
614-
get_activation(act) if act is not None else None,
615-
]
598+
layers: list[nn.Module] = []
599+
if norm is not None:
600+
layers.append(build_norm(norm, in_channels))
601+
if act is not None:
602+
layers.append(get_activation(act))
603+
616604
if factor == 1:
617605
layers.append(
618606
ConvLayer(
@@ -633,7 +621,8 @@ def build_decoder_project_out_block(
633621
)
634622
else:
635623
raise ValueError(f"upsample factor {factor} is not supported for decoder project out")
636-
return OpSequential(layers)
624+
block = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*layers))]))
625+
return block
637626

638627

639628
class Encoder(nn.Module):
@@ -671,7 +660,7 @@ def __init__(
671660
downsample_block_type=downsample_block_type,
672661
)
673662

674-
self.stages: list[OpSequential] = []
663+
self.stages: list[nn.Module] = []
675664
for stage_id, (width, depth) in enumerate(zip(width_list, depth_list)):
676665
stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
677666
stage = build_stage_main(
@@ -685,7 +674,7 @@ def __init__(
685674
shortcut=downsample_shortcut,
686675
)
687676
stage.append(downsample_block)
688-
self.stages.append(OpSequential(stage))
677+
self.stages.append(nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
689678
self.stages = nn.ModuleList(self.stages)
690679

691680
self.project_out = build_encoder_project_out_block(
@@ -743,7 +732,7 @@ def __init__(
743732
shortcut=in_shortcut,
744733
)
745734

746-
self.stages: list[OpSequential] = []
735+
self.stages: list[nn.Module] = []
747736
for stage_id, (width, depth) in reversed(list(enumerate(zip(width_list, depth_list)))):
748737
stage = []
749738
if stage_id < num_stages - 1 and depth > 0:
@@ -770,7 +759,7 @@ def __init__(
770759
),
771760
)
772761
)
773-
self.stages.insert(0, OpSequential(stage))
762+
self.stages.insert(0, nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
774763
self.stages = nn.ModuleList(self.stages)
775764

776765
self.project_out = build_decoder_project_out_block(

0 commit comments

Comments
 (0)