@@ -46,14 +46,15 @@ def __init__(
         act_func: str = "relu",
         dropout: float = 0.0,
         norm: bool = False,
+        norm_kwargs: Optional[dict] = dict(),
         bias: bool = True,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()

         self.norm = Norm(
-            func="layer" if norm else None, in_dim=in_len, dtype=dtype, device=device
+            func="layer" if norm else None, in_dim=in_len, **norm_kwargs, dtype=dtype, device=device
         )
         self.linear = nn.Linear(in_len, out_len, bias=bias, dtype=dtype, device=device)
         self.dropout = Dropout(dropout)
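
# --- Illustrative aside (not part of the diff) -----------------------------
# Minimal sketch of the new norm_kwargs pass-through on LinearBlock.
# Assumptions: the block is importable from grelu.model.blocks (path guessed
# from the repo layout), and Norm forwards these keyword arguments to
# torch.nn.LayerNorm, so options such as eps can now be set per block.
import torch
from grelu.model.blocks import LinearBlock  # import path assumed

block = LinearBlock(in_len=128, out_len=64, norm=True, norm_kwargs={"eps": 1e-6})
out = block(torch.randn(8, 128))  # (batch, in_len) -> (batch, out_len)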
@@ -122,7 +123,7 @@ def __init__(
         dropout: float = 0.0,
         norm: bool = True,
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         residual: bool = False,
         order: str = "CDNRA",
         bias: bool = True,
@@ -142,7 +143,6 @@ def __init__(
             "R",
         ], "The string supplied in order must contain one occurrence each of A, C, D, N and R."
         self.order = order
-        norm_kwargs = norm_kwargs or dict()

         # Create norm
         if norm:
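
# --- Illustrative aside (not part of the diff) -----------------------------
# With the default now an empty dict and the `norm_kwargs or dict()` guard
# removed, ConvBlock callers can forward BatchNorm options directly. Sketch
# only: the leading in_channels/out_channels/kernel_size arguments are
# assumed from the usual signature, and norm_kwargs is assumed to reach
# torch.nn.BatchNorm1d unchanged.
from grelu.model.blocks import ConvBlock  # import path assumed

conv = ConvBlock(
    in_channels=4,
    out_channels=32,
    kernel_size=5,
    norm=True,
    norm_type="batch",
    norm_kwargs={"momentum": 0.01, "eps": 1e-5},  # BatchNorm1d options
)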
@@ -250,7 +250,7 @@ def __init__(
         dropout: float = 0.0,
         order: str = "CDNA",
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         if_equal: bool = False,
         dtype=None,
         device=None,
@@ -265,7 +265,6 @@ def __init__(
             "N",
         ], "The string supplied in order must contain one occurrence each of A, C, D and N."
         self.order = order
-        norm_kwargs = norm_kwargs or dict()

         # Create batch norm
         if norm:
@@ -465,6 +464,7 @@ def __init__(
         dilation_mult: float = 1,
         act_func: str = "relu",
         norm: bool = False,
+        norm_kwargs: Optional[dict] = dict(),
         pool_func: Optional[str] = None,
         pool_size: Optional[int] = None,
         residual: bool = False,
@@ -507,6 +507,7 @@ def __init__(
                 dilation=dilation,
                 act_func=act_func,
                 norm=norm,
+                norm_kwargs=norm_kwargs,
                 residual=residual,
                 pool_func=pool_func,
                 pool_size=pool_size,
@@ -572,13 +573,15 @@ def __init__(
         in_len: int,
         dropout: float = 0.0,
         act_func: str = "relu",
+        norm_kwargs: Optional[dict] = dict(),
         **kwargs,
     ) -> None:
         super().__init__()
         self.dense1 = LinearBlock(
             in_len,
             in_len * 2,
             norm=True,
+            norm_kwargs=norm_kwargs,
             dropout=dropout,
             act_func=act_func,
             bias=True,
@@ -588,6 +591,7 @@ def __init__(
             in_len * 2,
             in_len,
             norm=False,
+            norm_kwargs=norm_kwargs,
             dropout=dropout,
             act_func=None,
             bias=True,
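
# --- Illustrative aside (not part of the diff) -----------------------------
# FeedForwardBlock now threads norm_kwargs through to its inner LinearBlocks;
# only dense1 actually normalizes, since dense2 is built with norm=False.
# Sketch, with the import path assumed as above.
from grelu.model.blocks import FeedForwardBlock  # import path assumed

ff = FeedForwardBlock(in_len=96, dropout=0.1, norm_kwargs={"eps": 1e-6})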
@@ -706,11 +710,12 @@ def __init__(
         key_len: Optional[int] = None,
         value_len: Optional[int] = None,
         pos_dropout: Optional[float] = None,
+        norm_kwargs: Optional[dict] = dict(),
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()
-        self.norm = Norm("layer", in_len)
+        self.norm = Norm("layer", in_len, **norm_kwargs)

         if flash_attn:
             if (
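
# --- Illustrative aside (not part of the diff) -----------------------------
# Rough equivalent of the updated pre-attention norm, assuming
# Norm("layer", in_len, **norm_kwargs) wraps torch.nn.LayerNorm and forwards
# the extra keyword arguments unchanged:
import torch.nn as nn

in_len, norm_kwargs = 512, {"eps": 1e-6}
pre_attn_norm = nn.LayerNorm(in_len, **norm_kwargs)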
@@ -752,6 +757,7 @@ def __init__(
             in_len=in_len,
             dropout=ff_dropout,
             act_func="relu",
+            norm_kwargs=norm_kwargs,
             dtype=dtype,
             device=device,
         )
@@ -808,6 +814,7 @@ def __init__(
         pos_dropout: float = 0.0,
         attn_dropout: float = 0.0,
         ff_dropout: float = 0.0,
+        norm_kwargs: Optional[dict] = dict(),
         flash_attn: bool = False,
         dtype=None,
         device=None,
@@ -825,6 +832,7 @@ def __init__(
                 key_len=key_len,
                 value_len=value_len,
                 pos_dropout=pos_dropout,
+                norm_kwargs=norm_kwargs,
                 dtype=dtype,
                 device=device,
             )
@@ -867,7 +875,7 @@ def __init__(
         in_channels: int,
         y_in_channels: int,
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         dtype=None,
         device=None,
     ) -> None:
@@ -877,10 +885,10 @@ def __init__(
             in_channels,
             1,
             norm=True,
-            act_func="gelu",
-            order="NACDR",
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
+            act_func="gelu_borzoi",
+            order="NACDR",
             dtype=dtype,
             device=device,
         )
@@ -891,7 +899,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu",
+            act_func="gelu_borzoi",
             order="NACD",
             if_equal=True,
             dtype=dtype,
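
# --- Illustrative aside (not part of the diff) -----------------------------
# The default activation here switches from "gelu" to "gelu_borzoi", i.e. the
# GELU variant matching the Borzoi model. PyTorch exposes both common
# formulations; which one "gelu_borzoi" maps to is defined by this module's
# Activation registry, which the diff does not show:
import torch.nn as nn

gelu_exact = nn.GELU()                   # erf-based GELU
gelu_tanh = nn.GELU(approximate="tanh")  # tanh-approximated GELU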