Commit c35c016

Merge pull request #147 from johahi/borzoi-small-fixes
Fixed small Borzoi discrepancies, fixes #144
2 parents: cd234d5 + a0288aa · commit c35c016

5 files changed: +40 −18 lines
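For orientation, the two numerical discrepancies addressed here — Borzoi's use of the tanh-approximate GELU and its normalisation epsilon of 0.001 — can be illustrated with plain PyTorch. This is a sketch for context only, not code from the repository:

import torch
import torch.nn as nn

# Discrepancy 1: Borzoi uses the tanh approximation of GELU rather than the exact
# erf-based form; the new 'gelu_borzoi' activation wraps nn.GELU(approximate="tanh").
x = torch.randn(8)
exact = nn.GELU()(x)
tanh_approx = nn.GELU(approximate="tanh")(x)
print((exact - tanh_approx).abs().max())  # small but nonzero difference

# Discrepancy 2: Borzoi's normalisation layers use eps=0.001 instead of PyTorch's
# default of 1e-5; norm_kwargs={"eps": 0.001} now threads this value through the blocks.
ln = nn.LayerNorm(16, eps=0.001)
print(ln.eps)  # 0.001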

src/grelu/model/blocks.py

Lines changed: 18 additions & 10 deletions

@@ -46,14 +46,15 @@ def __init__(
         act_func: str = "relu",
         dropout: float = 0.0,
         norm: bool = False,
+        norm_kwargs: Optional[dict] = dict(),
         bias: bool = True,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()

         self.norm = Norm(
-            func="layer" if norm else None, in_dim=in_len, dtype=dtype, device=device
+            func="layer" if norm else None, in_dim=in_len, **norm_kwargs, dtype=dtype, device=device
         )
         self.linear = nn.Linear(in_len, out_len, bias=bias, dtype=dtype, device=device)
         self.dropout = Dropout(dropout)
@@ -122,7 +123,7 @@ def __init__(
         dropout: float = 0.0,
         norm: bool = True,
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         residual: bool = False,
         order: str = "CDNRA",
         bias: bool = True,
@@ -142,7 +143,6 @@ def __init__(
             "R",
         ], "The string supplied in order must contain one occurrence each of A, C, D, N and R."
         self.order = order
-        norm_kwargs = norm_kwargs or dict()

         # Create norm
         if norm:
@@ -250,7 +250,7 @@ def __init__(
         dropout: float = 0.0,
         order: str = "CDNA",
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         if_equal: bool = False,
         dtype=None,
         device=None,
@@ -265,7 +265,6 @@ def __init__(
             "N",
         ], "The string supplied in order must contain one occurrence each of A, C, D and N."
         self.order = order
-        norm_kwargs = norm_kwargs or dict()

         # Create batch norm
         if norm:
@@ -465,6 +464,7 @@ def __init__(
         dilation_mult: float = 1,
         act_func: str = "relu",
         norm: bool = False,
+        norm_kwargs: Optional[dict] = dict(),
         pool_func: Optional[str] = None,
         pool_size: Optional[int] = None,
         residual: bool = False,
@@ -507,6 +507,7 @@ def __init__(
             dilation=dilation,
             act_func=act_func,
             norm=norm,
+            norm_kwargs=norm_kwargs,
             residual=residual,
             pool_func=pool_func,
             pool_size=pool_size,
@@ -572,13 +573,15 @@ def __init__(
         in_len: int,
         dropout: float = 0.0,
         act_func: str = "relu",
+        norm_kwargs: Optional[dict] = dict(),
         **kwargs,
     ) -> None:
         super().__init__()
         self.dense1 = LinearBlock(
             in_len,
             in_len * 2,
             norm=True,
+            norm_kwargs=norm_kwargs,
             dropout=dropout,
             act_func=act_func,
             bias=True,
@@ -588,6 +591,7 @@ def __init__(
             in_len * 2,
             in_len,
             norm=False,
+            norm_kwargs=norm_kwargs,
             dropout=dropout,
             act_func=None,
             bias=True,
@@ -706,11 +710,12 @@ def __init__(
         key_len: Optional[int] = None,
         value_len: Optional[int] = None,
         pos_dropout: Optional[float] = None,
+        norm_kwargs: Optional[dict] = dict(),
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()
-        self.norm = Norm("layer", in_len)
+        self.norm = Norm("layer", in_len, **norm_kwargs)

         if flash_attn:
             if (
@@ -752,6 +757,7 @@ def __init__(
             in_len=in_len,
             dropout=ff_dropout,
             act_func="relu",
+            norm_kwargs=norm_kwargs,
             dtype=dtype,
             device=device,
         )
@@ -808,6 +814,7 @@ def __init__(
         pos_dropout: float = 0.0,
         attn_dropout: float = 0.0,
         ff_dropout: float = 0.0,
+        norm_kwargs: Optional[dict] = dict(),
         flash_attn: bool = False,
         dtype=None,
         device=None,
@@ -825,6 +832,7 @@ def __init__(
             key_len=key_len,
             value_len=value_len,
             pos_dropout=pos_dropout,
+            norm_kwargs=norm_kwargs,
             dtype=dtype,
             device=device,
         )
@@ -867,7 +875,7 @@ def __init__(
         in_channels: int,
         y_in_channels: int,
         norm_type="batch",
-        norm_kwargs=None,
+        norm_kwargs: Optional[dict] = dict(),
         dtype=None,
         device=None,
     ) -> None:
@@ -877,10 +885,10 @@ def __init__(
             in_channels,
             1,
             norm=True,
-            act_func="gelu",
-            order="NACDR",
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
+            act_func="gelu_borzoi",
+            order="NACDR",
             dtype=dtype,
             device=device,
         )
@@ -891,7 +899,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu",
+            act_func="gelu_borzoi",
             order="NACD",
             if_equal=True,
             dtype=dtype,
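The practical effect of the norm_kwargs plumbing above is that normalisation keyword arguments now reach the Norm layer created inside each block. A minimal sketch, assuming the LinearBlock argument names shown in the hunks (the call itself is illustrative, not taken from the repository's tests):

import torch
from grelu.model.blocks import LinearBlock

# norm=True creates a layer norm inside the block; norm_kwargs is unpacked into it,
# so the Borzoi-style eps of 1e-3 replaces PyTorch's default of 1e-5.
block = LinearBlock(in_len=32, out_len=64, norm=True, norm_kwargs={"eps": 1e-3})
out = block(torch.randn(2, 32))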

src/grelu/model/heads.py

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ def __init__(
         act_func: Optional[str] = None,
         pool_func: Optional[str] = None,
         norm: bool = False,
+        norm_kwargs: Optional[dict] = dict(),
         dtype=None,
         device=None,
     ) -> None:
@@ -55,6 +56,7 @@ def __init__(
             self.n_tasks,
             act_func=self.act_func,
             norm=self.norm,
+            norm_kwargs=norm_kwargs,
             dtype=dtype,
             device=device,
         )

src/grelu/model/layers.py

Lines changed: 3 additions & 1 deletion

@@ -21,7 +21,7 @@ class Activation(nn.Module):

     Args:
         func: The type of activation function. Supported values are 'relu',
-            'elu', 'softplus', 'gelu', 'gelu_enformer' and 'exp'. If None, will return nn.Identity.
+            'elu', 'softplus', 'gelu', 'gelu_borzoi', 'gelu_enformer' and 'exp'. If None, will return nn.Identity.

     Raises:
         NotImplementedError: If 'func' is not a supported activation function.
@@ -36,6 +36,8 @@ def __init__(self, func: str) -> None:
             self.layer = nn.ELU()
         elif func == "gelu":
             self.layer = nn.GELU()
+        elif func == "gelu_borzoi":
+            self.layer = nn.GELU(approximate = 'tanh')
         elif func == "gelu_enformer":
             self.layer = GELU()
         elif func == "softplus":
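For reference, nn.GELU(approximate='tanh') computes the tanh approximation 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))) rather than the exact erf-based GELU, matching the approximate GELU used by the reference Borzoi implementation that this commit targets. A self-contained check of that identity (plain PyTorch, independent of gReLU):

import math
import torch

def gelu_tanh(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

x = torch.randn(100)
assert torch.allclose(gelu_tanh(x), torch.nn.GELU(approximate="tanh")(x), atol=1e-6)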

src/grelu/model/models.py

Lines changed: 11 additions & 4 deletions

@@ -515,8 +515,10 @@ def __init__(
         n_transformers: int = 8,
         key_len: int = 64,
         value_len: int = 192,
-        pos_dropout: float = 0.0,
-        attn_dropout: float = 0.0,
+        pos_dropout: float = 0.01,
+        attn_dropout: float = 0.05,
+        ff_dropout: float = 0.2,
+        norm_kwargs: Optional[dict] = {"eps" : 0.001},
         n_heads: int = 8,
         n_pos_features: int = 32,
         # Head
@@ -540,6 +542,8 @@ def __init__(
             value_len=value_len,
             pos_dropout=pos_dropout,
             attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            norm_kwargs=norm_kwargs,
             n_heads=n_heads,
             n_pos_features=n_pos_features,
             crop_len=crop_len,
@@ -551,6 +555,7 @@ def __init__(
             n_tasks,
             in_channels=round(channels * 1.25),
             norm=False,
+            norm_kwargs=norm_kwargs,
             act_func=final_act_func,
             pool_func=final_pool_func,
             dtype=dtype,
@@ -587,8 +592,10 @@ def __init__(
             n_transformers=8,
             key_len=64,
             value_len=192,
-            pos_dropout=0.0,
-            attn_dropout=0.0,
+            pos_dropout=0.01,
+            attn_dropout=0.05,
+            ff_dropout=0.2,
+            norm_kwargs={"eps": 0.001},
             n_heads=8,
             n_pos_features=32,
             final_act_func=None,
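With the new constructor defaults, a freshly built Borzoi-style model picks up the reference dropout rates and normalisation eps without extra arguments, and all of them remain overridable. A hedged sketch — the class name BorzoiModel and the n_tasks argument are assumed from context, not shown in this diff:

from grelu.model.models import BorzoiModel  # class name assumed, not shown in the hunks above

# Defaults after this commit (previously pos_dropout=0.0, attn_dropout=0.0, and no
# ff_dropout or norm_kwargs parameters).
model = BorzoiModel(
    n_tasks=1,                    # assumed required argument
    pos_dropout=0.01,
    attn_dropout=0.05,
    ff_dropout=0.2,
    norm_kwargs={"eps": 0.001},   # Borzoi-style normalisation eps
)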

src/grelu/model/trunks/borzoi.py

Lines changed: 6 additions & 3 deletions

@@ -71,7 +71,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu",
+            act_func="gelu_borzoi",
             order="NACDR",
             pool_func="max",
             pool_size=2,
@@ -142,6 +142,7 @@ def __init__(
         value_len: int,
         pos_dropout: float,
         attn_dropout: float,
+        ff_dropout: float,
         n_heads: int,
         n_pos_features: int,
         # Crop
@@ -173,6 +174,8 @@ def __init__(
             value_len=value_len,
             pos_dropout=pos_dropout,
             attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            norm_kwargs=norm_kwargs,
             n_heads=n_heads,
             n_pos_features=n_pos_features,
             flash_attn=flash_attn,
@@ -192,7 +195,7 @@ def __init__(
             in_channels=channels,
             out_channels=round(channels * 1.25),
             kernel_size=1,
-            act_func="gelu",
+            act_func="gelu_borzoi",
             dropout=0.1,
             norm=True,
             norm_type=norm_type,
@@ -201,7 +204,7 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        self.act = Activation("gelu")
+        self.act = Activation("gelu_borzoi")
         self.crop = Crop(crop_len=crop_len)

     def forward(self, x: Tensor) -> Tensor:
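One practical note on the new ff_dropout argument threaded through the trunk: like any nn.Dropout, it only acts in training mode, so it changes fine-tuning behaviour but not inference with pretrained weights (the gelu_borzoi and eps changes, by contrast, do alter forward-pass numerics). A plain-PyTorch illustration:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.2)  # the value now used as the Borzoi feed-forward dropout default
x = torch.ones(6)

drop.train()
print(drop(x))  # roughly 20% of elements zeroed, survivors scaled by 1/0.8

drop.eval()
print(drop(x))  # identical to x: dropout is a no-op at inference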
