@@ -299,6 +299,7 @@ def __init__(
         act_fn: Union[str, Tuple[str]] = "silu",
         upsample_block_type: str = "pixel_shuffle",
         in_shortcut: bool = True,
+        conv_act_fn: str = "relu",
     ):
         super().__init__()
 
@@ -349,7 +350,7 @@ def __init__(
         channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
 
         self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
-        self.conv_act = nn.ReLU()
+        self.conv_act = get_activation(conv_act_fn)
         self.conv_out = None
 
         if layers_per_block[0] > 0:
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             The normalization type(s) to use in the decoder.
         decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
             The activation function(s) to use in the decoder.
+        encoder_out_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the end of the encoder.
+        decoder_in_shortcut (`bool`, defaults to `True`):
+            Whether to use a shortcut at the beginning of the decoder.
+        decoder_conv_act_fn (`str`, defaults to `"relu"`):
+            The activation function to use at the end of the decoder.
         scaling_factor (`float`, defaults to `1.0`):
             The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
             space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ def __init__(
         downsample_block_type: str = "pixel_unshuffle",
         decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
         decoder_act_fns: Union[str, Tuple[str]] = "silu",
+        encoder_out_shortcut: bool = True,
+        decoder_in_shortcut: bool = True,
+        decoder_conv_act_fn: str = "relu",
         scaling_factor: float = 1.0,
     ) -> None:
         super().__init__()
@@ -454,6 +464,7 @@ def __init__(
             layers_per_block=encoder_layers_per_block,
             qkv_multiscales=encoder_qkv_multiscales,
             downsample_block_type=downsample_block_type,
+            out_shortcut=encoder_out_shortcut,
         )
         self.decoder = Decoder(
             in_channels=in_channels,
@@ -466,6 +477,8 @@ def __init__(
             norm_type=decoder_norm_types,
             act_fn=decoder_act_fns,
             upsample_block_type=upsample_block_type,
+            in_shortcut=decoder_in_shortcut,
+            conv_act_fn=decoder_conv_act_fn,
         )
 
         self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
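
Taken together, the diff exposes three new config flags on `AutoencoderDC` (`encoder_out_shortcut`, `decoder_in_shortcut`, `decoder_conv_act_fn`) and threads them through to the `Encoder` and `Decoder` submodules. Below is a minimal usage sketch, assuming a `diffusers` build that already includes this change; the input size and the encode/decode round trip are illustrative assumptions, not part of the diff:

```python
import torch
from diffusers import AutoencoderDC

# The new flags are ordinary __init__ arguments, so they are recorded in the
# model config alongside the existing AutoencoderDC options.
dcae = AutoencoderDC(
    encoder_out_shortcut=False,   # new: drop the shortcut at the end of the encoder
    decoder_in_shortcut=False,    # new: drop the shortcut at the start of the decoder
    decoder_conv_act_fn="silu",   # new: replaces the previously hard-coded nn.ReLU()
)
print(dcae.config.decoder_conv_act_fn)  # "silu"

# Illustrative round trip; the 256x256 input size is an arbitrary choice.
x = torch.randn(1, 3, 256, 256)
latents = dcae.encode(x, return_dict=False)[0]
reconstruction = dcae.decode(latents, return_dict=False)[0]
```

Swapping the hard-coded `nn.ReLU()` for `get_activation(conv_act_fn)` keeps the default behaviour unchanged (the default `conv_act_fn="relu"` still yields a ReLU) while letting DC-AE variants configure the decoder's output activation from the model config.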