src/diffusers/models/autoencoders/autoencoder_dc.py (14 additions, 1 deletion)
@@ -299,6 +299,7 @@ def __init__(
         act_fn: Union[str, Tuple[str]] = "silu",
         upsample_block_type: str = "pixel_shuffle",
         in_shortcut: bool = True,
+        conv_act_fn: str = "relu",
     ):
         super().__init__()
 
@@ -349,7 +350,7 @@ def __init__(
         channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
 
         self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
-        self.conv_act = nn.ReLU()
+        self.conv_act = get_activation(conv_act_fn)
         self.conv_out = None
 
         if layers_per_block[0] > 0:
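
For reviewers unfamiliar with the helper: `get_activation` is diffusers' string-to-module activation lookup, so the default `conv_act_fn="relu"` rebuilds the same `nn.ReLU()` the old code hard-coded. A minimal sketch of the behavior this hunk relies on (the mapping shown is an assumption covering common cases, not the helper's full table):

```python
# Minimal sketch of diffusers' get_activation helper (assumed mapping, not exhaustive).
import torch.nn as nn

def get_activation(act_fn: str) -> nn.Module:
    # Map a config string to a freshly constructed activation module.
    activations = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU, "mish": nn.Mish}
    act_fn = act_fn.lower()
    if act_fn not in activations:
        raise ValueError(f"Unsupported activation function: {act_fn}")
    return activations[act_fn]()

# get_activation("relu") -> nn.ReLU(), i.e. the previous hard-coded behavior.
```
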
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             The normalization type(s) to use in the decoder.
         decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
             The activation function(s) to use in the decoder.
+        encoder_out_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the end of the encoder.
+        decoder_in_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the beginning of the decoder.
+        decoder_conv_act_fn (`str`, defaults to `"relu"`):
+            The activation function to use at the end of the decoder.
         scaling_factor (`float`, defaults to `1.0`):
             The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
             space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ def __init__(
         downsample_block_type: str = "pixel_unshuffle",
         decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
         decoder_act_fns: Union[str, Tuple[str]] = "silu",
+        encoder_out_shortcut: bool = True,

> Contributor: Just noting for other reviewers that this change looks safe: the previous `out_shortcut` default in the encoder was also `True`.

+        decoder_in_shortcut: bool = True,

> Contributor: This change is safe because the previous `in_shortcut` default in the decoder was `True`.

+        decoder_conv_act_fn: str = "relu",
         scaling_factor: float = 1.0,
     ) -> None:
         super().__init__()
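
The two review comments above boil down to a backward-compatibility claim: the new keyword defaults reproduce the previously hard-coded behavior. A hedged sketch of how one could check that locally (the state-dict comparison is an assumed verification strategy, not part of the PR):

```python
# Hypothetical compatibility check: explicit new defaults should rebuild the old architecture.
import torch
from diffusers import AutoencoderDC

torch.manual_seed(0)
implicit = AutoencoderDC()  # relies on the new defaults: shortcuts True, "relu"

torch.manual_seed(0)
explicit = AutoencoderDC(
    encoder_out_shortcut=True,
    decoder_in_shortcut=True,
    decoder_conv_act_fn="relu",
)

# Identical parameter tensors imply an identical module graph and initialization.
assert all(
    torch.equal(a, b)
    for a, b in zip(implicit.state_dict().values(), explicit.state_dict().values())
)
```
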
@@ -454,6 +464,7 @@ def __init__(
             layers_per_block=encoder_layers_per_block,
             qkv_multiscales=encoder_qkv_multiscales,
             downsample_block_type=downsample_block_type,
+            out_shortcut=encoder_out_shortcut,
         )
         self.decoder = Decoder(
             in_channels=in_channels,
@@ -466,6 +477,8 @@
             norm_type=decoder_norm_types,
             act_fn=decoder_act_fns,
             upsample_block_type=upsample_block_type,
+            in_shortcut=decoder_in_shortcut,
+            conv_act_fn=decoder_conv_act_fn,
         )
 
         self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
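
A short usage sketch of the new configuration surface; the non-default values are illustrative only, not taken from the PR:

```python
# Illustrative use of the new AutoencoderDC arguments (values chosen for demonstration).
from diffusers import AutoencoderDC

vae = AutoencoderDC(
    decoder_conv_act_fn="silu",   # previously hard-coded to nn.ReLU() in the decoder
    encoder_out_shortcut=False,   # drop the shortcut at the end of the encoder
    decoder_in_shortcut=False,    # drop the shortcut at the start of the decoder
)
```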