@@ -299,6 +299,7 @@ def __init__(
299299 act_fn : Union [str , Tuple [str ]] = "silu" ,
300300 upsample_block_type : str = "pixel_shuffle" ,
301301 in_shortcut : bool = True ,
302+ conv_act_fn : str = "relu" ,
302303 ):
303304 super ().__init__ ()
304305
@@ -349,7 +350,7 @@ def __init__(
349350 channels = block_out_channels [0 ] if layers_per_block [0 ] > 0 else block_out_channels [1 ]
350351
351352 self .norm_out = RMSNorm (channels , 1e-5 , elementwise_affine = True , bias = True )
352- self .conv_act = nn . ReLU ( )
353+ self .conv_act = get_activation ( conv_act_fn )
353354 self .conv_out = None
354355
355356 if layers_per_block [0 ] > 0 :
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
414415 The normalization type(s) to use in the decoder.
415416 decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
416417 The activation function(s) to use in the decoder.
418+ encoder_out_shortcut (`bool`, defaults to `True`):
419+ Whether to use shortcut at the end of the encoder.
420+ decoder_in_shortcut (`bool`, defaults to `True`):
421+ Whether to use shortcut at the beginning of the decoder.
422+ decoder_conv_act_fn (`str`, defaults to `"relu"`):
423+ The activation function to use at the end of the decoder.
417424 scaling_factor (`float`, defaults to `1.0`):
418425 The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
419426 space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ def __init__(
441448 downsample_block_type : str = "pixel_unshuffle" ,
442449 decoder_norm_types : Union [str , Tuple [str ]] = "rms_norm" ,
443450 decoder_act_fns : Union [str , Tuple [str ]] = "silu" ,
451+ encoder_out_shortcut : bool = True ,
452+ decoder_in_shortcut : bool = True ,
453+ decoder_conv_act_fn : str = "relu" ,
444454 scaling_factor : float = 1.0 ,
445455 ) -> None :
446456 super ().__init__ ()
@@ -454,6 +464,7 @@ def __init__(
454464 layers_per_block = encoder_layers_per_block ,
455465 qkv_multiscales = encoder_qkv_multiscales ,
456466 downsample_block_type = downsample_block_type ,
467+ out_shortcut = encoder_out_shortcut ,
457468 )
458469 self .decoder = Decoder (
459470 in_channels = in_channels ,
@@ -466,6 +477,8 @@ def __init__(
466477 norm_type = decoder_norm_types ,
467478 act_fn = decoder_act_fns ,
468479 upsample_block_type = upsample_block_type ,
480+ in_shortcut = decoder_in_shortcut ,
481+ conv_act_fn = decoder_conv_act_fn ,
469482 )
470483
471484 self .spatial_compression_ratio = 2 ** (len (encoder_block_out_channels ) - 1 )
0 commit comments