Skip to content

Commit 2f6bbad

Browse files
committed
remove build_stage_main
1 parent 4f5cbb4 commit 2f6bbad

File tree

1 file changed

+48
-55
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+48
-55
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -439,41 +439,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
439439
return res
440440

441441

442-
def build_stage_main(
443-
width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int
444-
) -> list[nn.Module]:
445-
assert isinstance(block_type, str) or (isinstance(block_type, list) and depth == len(block_type))
446-
stage = []
447-
for d in range(depth):
448-
current_block_type = block_type[d] if isinstance(block_type, list) else block_type
449-
450-
in_channels = width if d > 0 else input_width
451-
out_channels = width
452-
453-
if current_block_type == "ResBlock":
454-
assert in_channels == out_channels
455-
block = ResBlock(
456-
in_channels=in_channels,
457-
out_channels=out_channels,
458-
kernel_size=3,
459-
stride=1,
460-
use_bias=(True, False),
461-
norm=(None, norm),
462-
act_func=(act, None),
463-
)
464-
elif current_block_type == "EViTGLU":
465-
assert in_channels == out_channels
466-
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
467-
elif current_block_type == "EViTS5GLU":
468-
assert in_channels == out_channels
469-
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
470-
else:
471-
raise ValueError(f"block_type {current_block_type} is not supported")
472-
473-
stage.append(block)
474-
return stage
475-
476-
477442
class Encoder(nn.Module):
478443
def __init__(
479444
self,
@@ -485,7 +450,6 @@ def __init__(
485450
norm: str = "rms2d",
486451
act: str = "silu",
487452
downsample_block_type: str = "ConvPixelUnshuffle",
488-
downsample_match_channel: bool = True,
489453
downsample_shortcut: Optional[str] = "averaging",
490454
out_norm: Optional[str] = None,
491455
out_act: Optional[str] = None,
@@ -533,11 +497,32 @@ def __init__(
533497
self.stages: list[nn.Module] = []
534498
for stage_id, (width, depth) in enumerate(zip(width_list, depth_list)):
535499
stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
536-
stage = build_stage_main(
537-
width=width, depth=depth, block_type=stage_block_type, norm=norm, act=act, input_width=width
538-
)
500+
if not (isinstance(stage_block_type, str) or (isinstance(stage_block_type, list) and depth == len(stage_block_type))):
501+
raise ValueError(f"block type {stage_block_type} is not supported for encoder stage {stage_id} with depth {depth}")
502+
stage = []
503+
# stage main
504+
for d in range(depth):
505+
current_block_type = stage_block_type[d] if isinstance(stage_block_type, list) else stage_block_type
506+
if current_block_type == "ResBlock":
507+
block = ResBlock(
508+
in_channels=width,
509+
out_channels=width,
510+
kernel_size=3,
511+
stride=1,
512+
use_bias=(True, False),
513+
norm=(None, norm),
514+
act_func=(act, None),
515+
)
516+
elif current_block_type == "EViTGLU":
517+
block = EfficientViTBlock(width, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
518+
elif current_block_type == "EViTS5GLU":
519+
block = EfficientViTBlock(width, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
520+
else:
521+
raise ValueError(f"block type {current_block_type} is not supported")
522+
stage.append(block)
523+
# downsample
539524
if stage_id < num_stages - 1 and depth > 0:
540-
downsample_out_channels = width_list[stage_id + 1] if downsample_match_channel else width
525+
downsample_out_channels = width_list[stage_id + 1]
541526
if downsample_block_type == "Conv":
542527
downsample_block = nn.Conv2d(
543528
in_channels=width,
@@ -621,7 +606,6 @@ def __init__(
621606
norm: str | list[str] = "rms2d",
622607
act: str | list[str] = "silu",
623608
upsample_block_type: str = "ConvPixelShuffle",
624-
upsample_match_channel: bool = True,
625609
upsample_shortcut: str = "duplicating",
626610
out_norm: str = "rms2d",
627611
out_act: str = "relu",
@@ -665,8 +649,9 @@ def __init__(
665649
self.stages: list[nn.Module] = []
666650
for stage_id, (width, depth) in reversed(list(enumerate(zip(width_list, depth_list)))):
667651
stage = []
652+
# upsample
668653
if stage_id < num_stages - 1 and depth > 0:
669-
upsample_out_channels = width if upsample_match_channel else width_list[stage_id + 1]
654+
upsample_out_channels = width
670655
if upsample_block_type == "ConvPixelShuffle":
671656
upsample_block = ConvPixelShuffleUpsample2D(
672657
in_channels=width_list[stage_id + 1], out_channels=upsample_out_channels, kernel_size=3, factor=2
@@ -685,22 +670,30 @@ def __init__(
685670
else:
686671
raise ValueError(f"shortcut {upsample_shortcut} is not supported for upsample")
687672
stage.append(upsample_block)
688-
673+
# stage main
689674
stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
690675
stage_norm = norm[stage_id] if isinstance(norm, list) else norm
691676
stage_act = act[stage_id] if isinstance(act, list) else act
692-
stage.extend(
693-
build_stage_main(
694-
width=width,
695-
depth=depth,
696-
block_type=stage_block_type,
697-
norm=stage_norm,
698-
act=stage_act,
699-
input_width=(
700-
width if upsample_match_channel else width_list[min(stage_id + 1, num_stages - 1)]
701-
),
702-
)
703-
)
677+
for d in range(depth):
678+
current_block_type = stage_block_type[d] if isinstance(stage_block_type, list) else stage_block_type
679+
if current_block_type == "ResBlock":
680+
block = ResBlock(
681+
in_channels=width,
682+
out_channels=width,
683+
kernel_size=3,
684+
stride=1,
685+
use_bias=(True, False),
686+
norm=(None, stage_norm),
687+
act_func=(stage_act, None),
688+
)
689+
elif current_block_type == "EViTGLU":
690+
block = EfficientViTBlock(width, norm=stage_norm, act_func=stage_act, local_module="GLUMBConv", scales=())
691+
elif current_block_type == "EViTS5GLU":
692+
block = EfficientViTBlock(width, norm=stage_norm, act_func=stage_act, local_module="GLUMBConv", scales=(5,))
693+
else:
694+
raise ValueError(f"block type {current_block_type} is not supported")
695+
stage.append(block)
696+
704697
self.stages.insert(0, nn.Sequential(*stage))
705698
self.stages = nn.ModuleList(self.stages)
706699

0 commit comments

Comments
 (0)