Skip to content

Commit cab56b1

Browse files
committed
remove inheritance of RMSNorm2d from LayerNorm
1 parent 30d6308 commit cab56b1

File tree

1 file changed

+95
-72
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+95
-72
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 95 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,31 @@
3333
from .vae import DecoderOutput
3434

3535

36-
class RMSNorm2d(nn.LayerNorm):
36+
class RMSNorm2d(nn.Module):
37+
def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True, device=None, dtype=None) -> None:
38+
factory_kwargs = {'device': device, 'dtype': dtype}
39+
super().__init__()
40+
self.num_features = num_features
41+
self.eps = eps
42+
self.elementwise_affine = elementwise_affine
43+
if self.elementwise_affine:
44+
self.weight = torch.nn.parameter.Parameter(torch.empty(self.num_features, **factory_kwargs))
45+
if bias:
46+
self.bias = torch.nn.parameter.Parameter(torch.empty(self.num_features, **factory_kwargs))
47+
else:
48+
self.register_parameter('bias', None)
49+
else:
50+
self.register_parameter('weight', None)
51+
self.register_parameter('bias', None)
52+
53+
self.reset_parameters()
54+
55+
def reset_parameters(self) -> None:
56+
if self.elementwise_affine:
57+
torch.nn.init.ones_(self.weight)
58+
if self.bias is not None:
59+
torch.nn.init.zeros_(self.bias)
60+
3761
def forward(self, x: torch.Tensor) -> torch.Tensor:
3862
x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
3963
if self.elementwise_affine:
@@ -74,7 +98,7 @@ def __init__(
7498
if norm is None:
7599
self.norm = None
76100
elif norm == "rms2d":
77-
self.norm = RMSNorm2d(normalized_shape=out_channels)
101+
self.norm = RMSNorm2d(num_features=out_channels)
78102
elif norm == "bn2d":
79103
self.norm = BatchNorm2d(num_features=out_channels)
80104
else:
@@ -469,54 +493,6 @@ def build_stage_main(
469493
return stage
470494

471495

472-
def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
473-
if block_type == "Conv":
474-
block = nn.Conv2d(
475-
in_channels=in_channels,
476-
out_channels=out_channels,
477-
kernel_size=3,
478-
stride=2,
479-
padding=1,
480-
)
481-
elif block_type == "ConvPixelUnshuffle":
482-
block = ConvPixelUnshuffleDownsample2D(
483-
in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
484-
)
485-
else:
486-
raise ValueError(f"block_type {block_type} is not supported for downsampling")
487-
if shortcut is None:
488-
pass
489-
elif shortcut == "averaging":
490-
shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
491-
in_channels=in_channels, out_channels=out_channels, factor=2
492-
)
493-
block = ResidualBlock(block, shortcut_block)
494-
else:
495-
raise ValueError(f"shortcut {shortcut} is not supported for downsample")
496-
return block
497-
498-
499-
def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
500-
if block_type == "ConvPixelShuffle":
501-
block = ConvPixelShuffleUpsample2D(
502-
in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
503-
)
504-
elif block_type == "InterpolateConv":
505-
block = Upsample2D(channels=in_channels, use_conv=True, out_channels=out_channels)
506-
else:
507-
raise ValueError(f"block_type {block_type} is not supported for upsampling")
508-
if shortcut is None:
509-
pass
510-
elif shortcut == "duplicating":
511-
shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
512-
in_channels=in_channels, out_channels=out_channels, factor=2
513-
)
514-
block = ResidualBlock(block, shortcut_block)
515-
else:
516-
raise ValueError(f"shortcut {shortcut} is not supported for upsample")
517-
return block
518-
519-
520496
class Encoder(nn.Module):
521497
def __init__(
522498
self,
@@ -547,18 +523,30 @@ def __init__(
547523

548524
# project in
549525
if depth_list[0] > 0:
550-
self.project_in = nn.Conv2d(
526+
project_in_block = nn.Conv2d(
551527
in_channels=in_channels,
552528
out_channels=width_list[0],
553529
kernel_size=3,
554530
padding=1,
555531
)
556532
elif depth_list[1] > 0:
557-
self.project_in = build_downsample_block(
558-
block_type=downsample_block_type, in_channels=in_channels, out_channels=width_list[1], shortcut=None
559-
)
533+
if downsample_block_type == "Conv":
534+
project_in_block = nn.Conv2d(
535+
in_channels=in_channels,
536+
out_channels=width_list[1],
537+
kernel_size=3,
538+
stride=2,
539+
padding=1,
540+
)
541+
elif downsample_block_type == "ConvPixelUnshuffle":
542+
project_in_block = ConvPixelUnshuffleDownsample2D(
543+
in_channels=in_channels, out_channels=width_list[1], kernel_size=3, factor=2
544+
)
545+
else:
546+
raise ValueError(f"block_type {downsample_block_type} is not supported for downsampling")
560547
else:
561548
raise ValueError(f"depth list {depth_list} is not supported for encoder project in")
549+
self.project_in = project_in_block
562550

563551
# stages
564552
self.stages: list[nn.Module] = []
@@ -568,12 +556,30 @@ def __init__(
568556
width=width, depth=depth, block_type=stage_block_type, norm=norm, act=act, input_width=width
569557
)
570558
if stage_id < num_stages - 1 and depth > 0:
571-
downsample_block = build_downsample_block(
572-
block_type=downsample_block_type,
573-
in_channels=width,
574-
out_channels=width_list[stage_id + 1] if downsample_match_channel else width,
575-
shortcut=downsample_shortcut,
576-
)
559+
downsample_out_channels = width_list[stage_id + 1] if downsample_match_channel else width
560+
if downsample_block_type == "Conv":
561+
downsample_block = nn.Conv2d(
562+
in_channels=width,
563+
out_channels=downsample_out_channels,
564+
kernel_size=3,
565+
stride=2,
566+
padding=1,
567+
)
568+
elif downsample_block_type == "ConvPixelUnshuffle":
569+
downsample_block = ConvPixelUnshuffleDownsample2D(
570+
in_channels=width, out_channels=downsample_out_channels, kernel_size=3, factor=2
571+
)
572+
else:
573+
raise ValueError(f"downsample_block_type {downsample_block_type} is not supported for downsampling")
574+
if downsample_shortcut is None:
575+
pass
576+
elif downsample_shortcut == "averaging":
577+
shortcut_block = PixelUnshuffleChannelAveragingDownsample2D(
578+
in_channels=width, out_channels=downsample_out_channels, factor=2
579+
)
580+
downsample_block = ResidualBlock(downsample_block, shortcut_block)
581+
else:
582+
raise ValueError(f"shortcut {downsample_shortcut} is not supported for downsample")
577583
stage.append(downsample_block)
578584
self.stages.append(nn.Sequential(OrderedDict([("op_list", nn.Sequential(*stage))])))
579585
self.stages = nn.ModuleList(self.stages)
@@ -583,7 +589,7 @@ def __init__(
583589
if out_norm is None:
584590
pass
585591
elif out_norm == "rms2d":
586-
project_out_layers.append(RMSNorm2d(normalized_shape=width_list[-1]))
592+
project_out_layers.append(RMSNorm2d(num_features=width_list[-1]))
587593
elif out_norm == "bn2d":
588594
project_out_layers.append(BatchNorm2d(num_features=width_list[-1]))
589595
else:
@@ -679,12 +685,24 @@ def __init__(
679685
for stage_id, (width, depth) in reversed(list(enumerate(zip(width_list, depth_list)))):
680686
stage = []
681687
if stage_id < num_stages - 1 and depth > 0:
682-
upsample_block = build_upsample_block(
683-
block_type=upsample_block_type,
684-
in_channels=width_list[stage_id + 1],
685-
out_channels=width if upsample_match_channel else width_list[stage_id + 1],
686-
shortcut=upsample_shortcut,
687-
)
688+
upsample_out_channels = width if upsample_match_channel else width_list[stage_id + 1]
689+
if upsample_block_type == "ConvPixelShuffle":
690+
upsample_block = ConvPixelShuffleUpsample2D(
691+
in_channels=width_list[stage_id + 1], out_channels=upsample_out_channels, kernel_size=3, factor=2
692+
)
693+
elif upsample_block_type == "InterpolateConv":
694+
upsample_block = Upsample2D(channels=width_list[stage_id + 1], use_conv=True, out_channels=upsample_out_channels)
695+
else:
696+
raise ValueError(f"upsample_block_type {upsample_block_type} is not supported")
697+
if upsample_shortcut is None:
698+
pass
699+
elif upsample_shortcut == "duplicating":
700+
shortcut_block = ChannelDuplicatingPixelUnshuffleUpsample2D(
701+
in_channels=width_list[stage_id + 1], out_channels=upsample_out_channels, factor=2
702+
)
703+
upsample_block = ResidualBlock(upsample_block, shortcut_block)
704+
else:
705+
raise ValueError(f"shortcut {upsample_shortcut} is not supported for upsample")
688706
stage.append(upsample_block)
689707

690708
stage_block_type = block_type[stage_id] if isinstance(block_type, list) else block_type
@@ -716,7 +734,7 @@ def __init__(
716734
if out_norm is None:
717735
pass
718736
elif out_norm == "rms2d":
719-
project_out_layers.append(RMSNorm2d(normalized_shape=project_out_in_channels))
737+
project_out_layers.append(RMSNorm2d(num_features=project_out_in_channels))
720738
elif out_norm == "bn2d":
721739
project_out_layers.append(BatchNorm2d(num_features=project_out_in_channels))
722740
else:
@@ -735,11 +753,16 @@ def __init__(
735753
)
736754
)
737755
elif depth_list[1] > 0:
738-
project_out_layers.append(
739-
build_upsample_block(
740-
block_type=upsample_block_type, in_channels=project_out_in_channels, out_channels=in_channels, shortcut=None
756+
if upsample_block_type == "ConvPixelShuffle":
757+
project_out_conv = ConvPixelShuffleUpsample2D(
758+
in_channels=project_out_in_channels, out_channels=in_channels, kernel_size=3, factor=2
741759
)
742-
)
760+
elif upsample_block_type == "InterpolateConv":
761+
project_out_conv = Upsample2D(channels=project_out_in_channels, use_conv=True, out_channels=in_channels)
762+
else:
763+
raise ValueError(f"upsample_block_type {upsample_block_type} is not supported for upsampling")
764+
765+
project_out_layers.append(project_out_conv)
743766
else:
744767
raise ValueError(f"depth list {depth_list} is not supported for decoder project out")
745768
self.project_out = nn.Sequential(OrderedDict([("op_list", nn.Sequential(*project_out_layers))]))

0 commit comments

Comments (0)