@@ -34,6 +34,7 @@ class LinearBlock(nn.Module):
         act_func: Name of activation function
         dropout: Dropout probability
         norm: If True, apply layer normalization
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layer
         bias: If True, include bias term.
         dtype: Data type of the weights
         device: Device on which to store the weights
@@ -46,15 +47,15 @@ def __init__(
         act_func: str = "relu",
         dropout: float = 0.0,
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         bias: bool = True,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()

         self.norm = Norm(
-            func="layer" if norm else None, in_dim=in_len, **norm_kwargs, dtype=dtype, device=device
+            func="layer" if norm else None, in_dim=in_len, **(norm_kwargs or dict()), dtype=dtype, device=device
         )
         self.linear = nn.Linear(in_len, out_len, bias=bias, dtype=dtype, device=device)
         self.dropout = Dropout(dropout)
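Side note on the default change above: `dict()` as a default argument is created once at function-definition time and shared across all calls, which is the classic Python mutable-default pitfall; the `None` sentinel plus `**(norm_kwargs or dict())` unpacks a fresh dict each time. A minimal standalone sketch (toy functions, not the library's code):

```python
from typing import Optional


def risky(norm_kwargs: dict = dict()) -> dict:
    # The default dict is built once and reused by every call that omits the argument.
    norm_kwargs.setdefault("eps", 1e-3)
    return norm_kwargs


def safe(norm_kwargs: Optional[dict] = None) -> dict:
    # The None sentinel means each call builds its own dict.
    kwargs = norm_kwargs or dict()
    kwargs.setdefault("eps", 1e-3)
    return kwargs


print(risky() is risky())  # True  -- the same dict object leaks between calls
print(safe() is safe())    # False -- every call gets a fresh dict
```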
@@ -123,7 +124,7 @@ def __init__(
         dropout: float = 0.0,
         norm: bool = True,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         residual: bool = False,
         order: str = "CDNRA",
         bias: bool = True,
@@ -152,15 +153,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     norm_type,
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -231,7 +232,7 @@ class ChannelTransformBlock(nn.Module):
         act_func: Name of the activation function
         dropout: Dropout probability
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNA" means that the
             operations will be performed in the order: convolution, dropout,
@@ -250,7 +251,7 @@ def __init__(
         dropout: float = 0.0,
         order: str = "CDNA",
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         if_equal: bool = False,
         dtype=None,
         device=None,
@@ -274,15 +275,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     "batch",
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -441,6 +442,7 @@ class ConvTower(nn.Module):
         pool_size: Width of the pooling layers
         dropout: Dropout probability
         norm: If True, apply batch norm
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         residual: If True, apply residual connection
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNRA" means that the
@@ -464,7 +466,7 @@ def __init__(
         dilation_mult: float = 1,
         act_func: str = "relu",
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         pool_func: Optional[str] = None,
         pool_size: Optional[int] = None,
         residual: bool = False,
@@ -565,15 +567,16 @@ class FeedForwardBlock(nn.Module):
         in_len: Length of the input tensor
         dropout: Dropout probability
         act_func: Name of the activation function
-        kwargs: Additional arguments to be passed to the linear layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        **kwargs: Additional arguments to be passed to the linear layers
     """

     def __init__(
         self,
         in_len: int,
         dropout: float = 0.0,
         act_func: str = "relu",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -693,6 +696,7 @@ class TransformerBlock(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -710,12 +714,12 @@ def __init__(
         key_len: Optional[int] = None,
         value_len: Optional[int] = None,
         pos_dropout: Optional[float] = None,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()
-        self.norm = Norm("layer", in_len, **norm_kwargs)
+        self.norm = Norm("layer", in_len, **(norm_kwargs or dict()))

         if flash_attn:
             if (
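For context on what ends up in `norm_kwargs`: a sketch assuming that, for `func="layer"`, `Norm` forwards its extra keyword arguments to `torch.nn.LayerNorm` (the plumbing inside `Norm` is not shown in this diff, so the example uses `LayerNorm` directly):

```python
import torch
import torch.nn as nn

# e.g. a caller constructing TransformerBlock(..., norm_kwargs={"eps": 1e-3})
norm_kwargs = {"eps": 1e-3}
layer_norm = nn.LayerNorm(128, **(norm_kwargs or dict()))

x = torch.randn(2, 16, 128)
print(layer_norm(x).shape)  # torch.Size([2, 16, 128])
```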
@@ -795,10 +799,10 @@ class TransformerTower(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
-        attn_dropout: Dropout probability in the output layer
-        ff_droppout: Dropout probability in the linear feed-forward layers
-        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings. key_len, value_len,
-            pos_dropout and n_pos_features are ignored.
+        attn_dropout: Dropout probability in the attention layer
+        ff_dropout: Dropout probability in the feed-forward layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -814,7 +818,7 @@ def __init__(
         pos_dropout: float = 0.0,
         attn_dropout: float = 0.0,
         ff_dropout: float = 0.0,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         flash_attn: bool = False,
         dtype=None,
         device=None,
@@ -865,17 +869,20 @@ class UnetBlock(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representation.
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
-        device: Device on which to store the weights
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+            tanh approximation (different from PyTorch's default GELU implementation).
         dtype: Data type of the weights
+        device: Device on which to store the weights
     """

     def __init__(
         self,
         in_channels: int,
         y_in_channels: int,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
+        act_func="gelu_borzoi",
         dtype=None,
         device=None,
     ) -> None:
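To make the new `act_func` docstring concrete, here is a small comparison of PyTorch's exact (erf-based) GELU with the tanh approximation the docstring refers to; how the library maps the `'gelu_borzoi'` name to this activation is outside this diff:

```python
import torch
import torch.nn as nn

x = torch.linspace(-3, 3, 7)
exact = nn.GELU()(x)                     # PyTorch default: exact, erf-based GELU
approx = nn.GELU(approximate="tanh")(x)  # tanh approximation, per the docstring above
print((exact - approx).abs().max())      # small but nonzero difference
```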
@@ -887,7 +894,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACDR",
             dtype=dtype,
             device=device,
@@ -899,7 +906,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACD",
             if_equal=True,
             dtype=dtype,
@@ -932,16 +939,18 @@ class UnetTower(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representations.
         n_blocks: Number of U-net blocks
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi', which uses the
+            tanh approximation (different from PyTorch's default GELU implementation).
         kwargs: Additional arguments to be passed to the U-net blocks
     """

     def __init__(
-        self, in_channels: int, y_in_channels: List[int], n_blocks: int, **kwargs
+        self, in_channels: int, y_in_channels: List[int], n_blocks: int, act_func: str = "gelu_borzoi", **kwargs
     ) -> None:
         super().__init__()
         self.blocks = nn.ModuleList()
         for y_c in y_in_channels:
-            self.blocks.append(UnetBlock(in_channels, y_c, **kwargs))
+            self.blocks.append(UnetBlock(in_channels, y_c, act_func=act_func, **kwargs))

     def forward(self, x: Tensor, ys: List[Tensor]) -> Tensor:
         """
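With `act_func` threaded from `UnetTower` into each `UnetBlock`, the activation for the whole tower can now be overridden in one place. A hypothetical usage sketch; the import path and channel sizes are assumptions for illustration, not taken from this diff:

```python
from grelu.model.blocks import UnetTower  # import path assumed; adjust to your package layout

tower = UnetTower(
    in_channels=512,
    y_in_channels=[256, 128],   # one UnetBlock is created per entry
    n_blocks=2,
    act_func="relu",            # forwarded to every UnetBlock; default remains "gelu_borzoi"
    norm_kwargs={"eps": 1e-3},  # forwarded via **kwargs to each block's normalization layers
)
```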