
Commit a109d12

Merge pull request #153 from MuhammedHasan/better-defaults-borzoi-activation-norm
better-defaults-borzoi-activation-norm
2 parents f994085 + 5c48bf4 commit a109d12

5 files changed: +93, -35 lines

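The changes below follow two patterns. First, every `norm_kwargs: Optional[dict] = dict()` default becomes `None` and is expanded as `**(norm_kwargs or dict())`, avoiding a mutable default argument shared across calls. Second, the Borzoi and U-net classes gain an `act_func` argument defaulting to `'gelu_borzoi'` (a tanh-approximated GELU) instead of hard-coding it. A minimal sketch of the mutable-default pitfall (hypothetical code, not from gReLU):

```python
from typing import Optional

def bad_block(norm_kwargs: dict = dict()):
    # The default dict is created once at definition time and shared by
    # every call, so a mutation here leaks into later calls.
    norm_kwargs.setdefault("eps", 1e-3)
    return norm_kwargs

def good_block(norm_kwargs: Optional[dict] = None):
    # Pattern used throughout this commit: resolve None to a fresh dict.
    norm_kwargs = norm_kwargs or dict()
    norm_kwargs.setdefault("eps", 1e-3)
    return norm_kwargs

print(bad_block() is bad_block())    # True: the same dict object every call
print(good_block() is good_block())  # False: independent dicts per call
```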

src/grelu/model/blocks.py

Lines changed: 35 additions & 26 deletions
@@ -34,6 +34,7 @@ class LinearBlock(nn.Module):
         act_func: Name of activation function
         dropout: Dropout probability
         norm: If True, apply layer normalization
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layer
         bias: If True, include bias term.
         dtype: Data type of the weights
         device: Device on which to store the weights
@@ -46,15 +47,15 @@ def __init__(
         act_func: str = "relu",
         dropout: float = 0.0,
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         bias: bool = True,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()
 
         self.norm = Norm(
-            func="layer" if norm else None, in_dim=in_len, **norm_kwargs, dtype=dtype, device=device
+            func="layer" if norm else None, in_dim=in_len, **(norm_kwargs or dict()), dtype=dtype, device=device
         )
         self.linear = nn.Linear(in_len, out_len, bias=bias, dtype=dtype, device=device)
         self.dropout = Dropout(dropout)
@@ -123,7 +124,7 @@ def __init__(
         dropout: float = 0.0,
         norm: bool = True,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         residual: bool = False,
         order: str = "CDNRA",
         bias: bool = True,
@@ -152,15 +153,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     norm_type,
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -231,7 +232,7 @@ class ChannelTransformBlock(nn.Module):
         act_func: Name of the activation function
         dropout: Dropout probability
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNA" means that the
             operations will be performed in the order: convolution, dropout,
@@ -250,7 +251,7 @@ def __init__(
         dropout: float = 0.0,
         order: str = "CDNA",
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         if_equal: bool = False,
         dtype=None,
         device=None,
@@ -274,15 +275,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     "batch",
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -441,6 +442,7 @@ class ConvTower(nn.Module):
         pool_size: Width of the pooling layers
         dropout: Dropout probability
         norm: If True, apply batch norm
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         residual: If True, apply residual connection
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNRA" means that the
@@ -464,7 +466,7 @@ def __init__(
         dilation_mult: float = 1,
         act_func: str = "relu",
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         pool_func: Optional[str] = None,
         pool_size: Optional[int] = None,
         residual: bool = False,
@@ -565,15 +567,16 @@ class FeedForwardBlock(nn.Module):
         in_len: Length of the input tensor
         dropout: Dropout probability
         act_func: Name of the activation function
-        kwargs: Additional arguments to be passed to the linear layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        **kwargs: Additional arguments to be passed to the linear layers
     """
 
     def __init__(
         self,
         in_len: int,
         dropout: float = 0.0,
         act_func: str = "relu",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -693,6 +696,7 @@ class TransformerBlock(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -710,12 +714,12 @@ def __init__(
         key_len: Optional[int] = None,
         value_len: Optional[int] = None,
         pos_dropout: Optional[float] = None,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()
-        self.norm = Norm("layer", in_len, **norm_kwargs)
+        self.norm = Norm("layer", in_len, **(norm_kwargs or dict()))
 
         if flash_attn:
             if (
@@ -795,10 +799,10 @@ class TransformerTower(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
-        attn_dropout: Dropout probability in the output layer
-        ff_droppout: Dropout probability in the linear feed-forward layers
-        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings. key_len, value_len,
-            pos_dropout and n_pos_features are ignored.
+        attn_dropout: Dropout probability in the attention layer
+        ff_dropout: Dropout probability in the feed-forward layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -814,7 +818,7 @@ def __init__(
         pos_dropout: float = 0.0,
         attn_dropout: float = 0.0,
         ff_dropout: float = 0.0,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         flash_attn: bool = False,
         dtype=None,
         device=None,
@@ -865,17 +869,20 @@ class UnetBlock(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representation.
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
-        device: Device on which to store the weights
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi' which uses
+            tanh approximation (different from PyTorch's default GELU implementation).
         dtype: Data type of the weights
+        device: Device on which to store the weights
     """
 
     def __init__(
         self,
         in_channels: int,
         y_in_channels: int,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
+        act_func="gelu_borzoi",
         dtype=None,
         device=None,
     ) -> None:
@@ -887,7 +894,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACDR",
             dtype=dtype,
             device=device,
@@ -899,7 +906,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACD",
             if_equal=True,
             dtype=dtype,
@@ -932,16 +939,18 @@ class UnetTower(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representations.
         n_blocks: Number of U-net blocks
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi' which uses
+            tanh approximation (different from PyTorch's default GELU implementation).
         kwargs: Additional arguments to be passed to the U-net blocks
     """
 
     def __init__(
-        self, in_channels: int, y_in_channels: List[int], n_blocks: int, **kwargs
+        self, in_channels: int, y_in_channels: List[int], n_blocks: int, act_func: str = "gelu_borzoi", **kwargs
     ) -> None:
         super().__init__()
         self.blocks = nn.ModuleList()
         for y_c in y_in_channels:
-            self.blocks.append(UnetBlock(in_channels, y_c, **kwargs))
+            self.blocks.append(UnetBlock(in_channels, y_c, act_func=act_func, **kwargs))
 
     def forward(self, x: Tensor, ys: List[Tensor]) -> Tensor:
         """

src/grelu/model/heads.py

Lines changed: 4 additions & 2 deletions
@@ -27,6 +27,8 @@ class ConvHead(nn.Module):
         norm: If True, batch normalization will be included.
         act_func: Activation function for the convolutional layer
         pool_func: Pooling function.
+        norm: If True, batch normalization will be included.
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layer
         dtype: Data type for the layers.
         device: Device for the layers.
     """
@@ -38,7 +40,7 @@ def __init__(
         act_func: Optional[str] = None,
         pool_func: Optional[str] = None,
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         dtype=None,
         device=None,
     ) -> None:
@@ -56,7 +58,7 @@ def __init__(
             self.n_tasks,
             act_func=self.act_func,
             norm=self.norm,
-            norm_kwargs=norm_kwargs,
+            norm_kwargs=(norm_kwargs or dict()),
             dtype=dtype,
             device=device,
         )
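
A sketch of passing normalization options through `ConvHead` after this change. Only the keyword arguments shown in the diff are confirmed; the leading `n_tasks` and `in_channels` arguments are assumed from the rest of gReLU:

```python
from grelu.model.heads import ConvHead

# norm_kwargs is forwarded to the head's normalization layer as a plain dict;
# passing None (the new default) behaves like an empty dict.
head = ConvHead(
    n_tasks=10,        # assumed argument, not shown in this diff
    in_channels=512,   # assumed argument, not shown in this diff
    norm=True,
    norm_kwargs={"eps": 1e-3},
)
head_default = ConvHead(n_tasks=10, in_channels=512, norm=True)
```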

src/grelu/model/layers.py

Lines changed: 17 additions & 2 deletions
@@ -20,8 +20,15 @@ class Activation(nn.Module):
     A nonlinear activation layer.
 
     Args:
-        func: The type of activation function. Supported values are 'relu',
-            'elu', 'softplus', 'gelu', 'gelu_borzoi', 'gelu_enformer' and 'exp'. If None, will return nn.Identity.
+        func: The type of activation function. Supported values are:
+            - 'relu': Standard ReLU activation
+            - 'elu': Exponential Linear Unit
+            - 'softplus': Softplus activation
+            - 'gelu': Standard GELU activation using PyTorch's default approximation
+            - 'gelu_borzoi': GELU activation using tanh approximation (different from PyTorch's default)
+            - 'gelu_enformer': Custom GELU implementation from Enformer
+            - 'exp': Exponential activation
+            - None: Returns identity function (no activation)
 
     Raises:
         NotImplementedError: If 'func' is not a supported activation function.
@@ -159,6 +166,14 @@ class Norm(nn.Module):
             'syncbatch', 'instance', or 'layer'. If None, will return nn.Identity.
         in_dim: Number of features in the input tensor.
         **kwargs: Additional arguments to pass to the normalization function.
+            Common arguments include:
+            - eps: Small constant added to denominator for numerical stability.
+                Defaults to 1e-5 for all normalization types unless overridden.
+            - momentum: Value used for the running_mean and running_var computation.
+                Defaults to 0.1 for batch and sync batch norm.
+            - affine: If True, adds learnable affine parameters. Defaults to True.
+            - track_running_stats: If True, tracks running mean and variance.
+                Defaults to True for batch and sync batch norm.
     """
 
     def __init__(
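
A sketch exercising the documented activation variants and `Norm` keyword arguments, assuming `Activation` and `Norm` are importable from `grelu.model.layers` as the file path suggests:

```python
import torch
from grelu.model.layers import Activation, Norm

x = torch.linspace(-3, 3, 7)
gelu = Activation("gelu")                # PyTorch's default GELU
gelu_borzoi = Activation("gelu_borzoi")  # tanh-approximated GELU, per the docstring

# The two variants agree closely but not exactly.
print(torch.max(torch.abs(gelu(x) - gelu_borzoi(x))))

# Extra keyword arguments are forwarded to the underlying torch norm layer.
norm = Norm("batch", in_dim=64, eps=1e-3)
print(norm)
```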

src/grelu/model/models.py

Lines changed: 26 additions & 2 deletions
@@ -496,6 +496,10 @@ class BorzoiModel(BaseModel):
             If None, no pooling will be applied at the end.
         flash_attn: If True, uses Flash Attention with Rotational Position Embeddings. key_len, value_len,
             pos_dropout and n_pos_features are ignored.
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers.
+            Defaults to {"eps": 0.001}.
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi' which uses
+            tanh approximation (different from PyTorch's default GELU implementation).
         dtype: Data type for the layers.
         device: Device for the layers.
     """
@@ -518,17 +522,19 @@ def __init__(
         pos_dropout: float = 0.01,
         attn_dropout: float = 0.05,
         ff_dropout: float = 0.2,
-        norm_kwargs: Optional[dict] = {"eps" : 0.001},
+        norm_kwargs: Optional[dict] = None,
         n_heads: int = 8,
         n_pos_features: int = 32,
         # Head
         crop_len: int = 16,
+        act_func: str = "gelu_borzoi",
         final_act_func: Optional[str] = None,
         final_pool_func: Optional[str] = "avg",
         flash_attn=False,
         dtype=None,
         device=None,
     ) -> None:
+        norm_kwargs = norm_kwargs or {"eps": 0.001}
         super().__init__(
             embedding=BorzoiTrunk(
                 stem_channels=stem_channels,
@@ -548,6 +554,7 @@ def __init__(
                 n_pos_features=n_pos_features,
                 crop_len=crop_len,
                 flash_attn=flash_attn,
+                act_func=act_func,
                 dtype=dtype,
                 device=device,
             ),
@@ -567,6 +574,19 @@ def __init__(
 class BorzoiPretrainedModel(BaseModel):
     """
     Borzoi model with published weights (ported from Keras).
+
+    Args:
+        n_tasks: Number of tasks for the model to predict
+        fold: Which fold of the model to load (default=0)
+        n_transformers: Number of transformer blocks to use (default=8)
+        crop_len: Number of positions to crop at either end of the output (default=0)
+        act_func: Name of the activation function. Defaults to 'gelu_borzoi' which uses
+            tanh approximation (different from PyTorch's default GELU implementation).
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers.
+            Defaults to {"eps": 0.001}.
+        final_pool_func: Name of the pooling function to apply to the final output (default="avg")
+        dtype: Data type for the layers
+        device: Device for the layers
     """
 
     def __init__(
@@ -577,10 +597,13 @@ def __init__(
         n_transformers: int = 8,
         # head
         crop_len=0,
+        act_func="gelu_borzoi",
+        norm_kwargs: Optional[dict] = None,
         final_pool_func="avg",
         dtype=None,
         device=None,
     ):
+        norm_kwargs = norm_kwargs or {"eps": 0.001}
         model = BorzoiModel(
             crop_len=crop_len,
             n_tasks=7611,
@@ -595,9 +618,10 @@ def __init__(
             pos_dropout=0.01,
             attn_dropout=0.05,
             ff_dropout=0.2,
-            norm_kwargs={"eps": 0.001},
+            norm_kwargs=norm_kwargs,
             n_heads=8,
             n_pos_features=32,
+            act_func=act_func,
             final_act_func=None,
             final_pool_func=None,
             dtype=dtype,
