Skip to content

Commit 96e844b

Browse files
committed
update other blocks to support the removal of build_norm
1 parent 25ae389 commit 96e844b

File tree

1 file changed

+29
-6
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+29
-6
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,14 @@ def __init__(
7373
groups=groups,
7474
bias=use_bias,
7575
)
76-
self.norm = build_norm(norm, num_features=out_channels)
76+
if norm is None:
77+
self.norm = None
78+
elif norm == "rms2d":
79+
self.norm = RMSNorm2d(normalized_shape=out_channels)
80+
elif norm == "bn2d":
81+
self.norm = BatchNorm2d(num_features=out_channels)
82+
else:
83+
raise ValueError(f"norm {norm} is not supported")
7784
self.act = get_activation(act_func) if act_func is not None else None
7885

7986
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -532,9 +539,17 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:
532539
def build_encoder_project_out_block(
533540
in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
534541
):
535-
layers = []
536-
if norm is not None:
537-
layers.append(build_norm(norm))
542+
layers: list[nn.Module] = []
543+
544+
if norm is None:
545+
pass
546+
elif norm == "rms2d":
547+
layers.append(RMSNorm2d(normalized_shape=in_channels))
548+
elif norm == "bn2d":
549+
layers.append(BatchNorm2d(num_features=in_channels))
550+
else:
551+
raise ValueError(f"norm {norm} is not supported")
552+
538553
if act is not None:
539554
layers.append(get_activation(act))
540555
layers.append(ConvLayer(
@@ -586,8 +601,16 @@ def build_decoder_project_out_block(
586601
in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
587602
):
588603
layers: list[nn.Module] = []
589-
if norm is not None:
590-
layers.append(build_norm(norm, in_channels))
604+
605+
if norm is None:
606+
pass
607+
elif norm == "rms2d":
608+
layers.append(RMSNorm2d(normalized_shape=in_channels))
609+
elif norm == "bn2d":
610+
layers.append(BatchNorm2d(num_features=in_channels))
611+
else:
612+
raise ValueError(f"norm {norm} is not supported")
613+
591614
if act is not None:
592615
layers.append(get_activation(act))
593616

0 commit comments

Comments
 (0)