Commit 5c65f82

Author: Muhammed Hasan Celik
Commit message: argument fix
Parent: f994085

File tree: 4 files changed (+52, -32 lines)

src/grelu/model/blocks.py
Lines changed: 32 additions & 25 deletions

@@ -34,6 +34,7 @@ class LinearBlock(nn.Module):
         act_func: Name of activation function
         dropout: Dropout probability
         norm: If True, apply layer normalization
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layer
         bias: If True, include bias term.
         dtype: Data type of the weights
         device: Device on which to store the weights
@@ -46,15 +47,15 @@ def __init__(
         act_func: str = "relu",
         dropout: float = 0.0,
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         bias: bool = True,
         dtype=None,
         device=None,
     ) -> None:
         super().__init__()

         self.norm = Norm(
-            func="layer" if norm else None, in_dim=in_len, **norm_kwargs, dtype=dtype, device=device
+            func="layer" if norm else None, in_dim=in_len, **(norm_kwargs or dict()), dtype=dtype, device=device
         )
         self.linear = nn.Linear(in_len, out_len, bias=bias, dtype=dtype, device=device)
         self.dropout = Dropout(dropout)
@@ -123,7 +124,7 @@ def __init__(
         dropout: float = 0.0,
         norm: bool = True,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         residual: bool = False,
         order: str = "CDNRA",
         bias: bool = True,
@@ -152,15 +153,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     norm_type,
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -231,7 +232,7 @@ class ChannelTransformBlock(nn.Module):
         act_func: Name of the activation function
         dropout: Dropout probability
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNA" means that the
             operations will be performed in the order: convolution, dropout,
@@ -250,7 +251,7 @@ def __init__(
         dropout: float = 0.0,
         order: str = "CDNA",
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         if_equal: bool = False,
         dtype=None,
         device=None,
@@ -274,15 +275,15 @@ def __init__(
                     in_dim=out_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
             else:
                 self.norm = Norm(
                     "batch",
                     in_dim=in_channels,
                     dtype=dtype,
                     device=device,
-                    **norm_kwargs,
+                    **(norm_kwargs or dict()),
                 )
         else:
             self.norm = Norm(None)
@@ -441,6 +442,7 @@ class ConvTower(nn.Module):
         pool_size: Width of the pooling layers
         dropout: Dropout probability
         norm: If True, apply batch norm
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         residual: If True, apply residual connection
         order: A string representing the order in which operations are
             to be performed on the input. For example, "CDNRA" means that the
@@ -464,7 +466,7 @@ def __init__(
         dilation_mult: float = 1,
         act_func: str = "relu",
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         pool_func: Optional[str] = None,
         pool_size: Optional[int] = None,
         residual: bool = False,
@@ -565,15 +567,16 @@ class FeedForwardBlock(nn.Module):
         in_len: Length of the input tensor
         dropout: Dropout probability
         act_func: Name of the activation function
-        kwargs: Additional arguments to be passed to the linear layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        **kwargs: Additional arguments to be passed to the linear layers
     """

     def __init__(
         self,
         in_len: int,
         dropout: float = 0.0,
         act_func: str = "relu",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -693,6 +696,7 @@ class TransformerBlock(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -710,7 +714,7 @@ def __init__(
         key_len: Optional[int] = None,
         value_len: Optional[int] = None,
         pos_dropout: Optional[float] = None,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         dtype=None,
         device=None,
     ) -> None:
@@ -795,10 +799,10 @@ class TransformerTower(nn.Module):
         key_len: Length of the key vectors
         value_len: Length of the value vectors.
         pos_dropout: Dropout probability in the positional embeddings
-        attn_dropout: Dropout probability in the output layer
-        ff_droppout: Dropout probability in the linear feed-forward layers
-        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings. key_len, value_len,
-            pos_dropout and n_pos_features are ignored.
+        attn_dropout: Dropout probability in the attention layer
+        ff_dropout: Dropout probability in the feed-forward layers
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        flash_attn: If True, uses Flash Attention with Rotational Position Embeddings
         dtype: Data type of the weights
         device: Device on which to store the weights
     """
@@ -814,7 +818,7 @@ def __init__(
         pos_dropout: float = 0.0,
         attn_dropout: float = 0.0,
         ff_dropout: float = 0.0,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         flash_attn: bool = False,
         dtype=None,
         device=None,
@@ -865,17 +869,19 @@ class UnetBlock(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representation.
         norm_type: Type of normalization to apply: 'batch', 'syncbatch', 'layer', 'instance' or None
-        norm_kwargs: Additional arguments to be passed to the normalization layer
-        device: Device on which to store the weights
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layers
+        act_func: Name of the activation function
         dtype: Data type of the weights
+        device: Device on which to store the weights
     """

     def __init__(
         self,
         in_channels: int,
         y_in_channels: int,
         norm_type="batch",
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
+        act_func="gelu_borzoi",
         dtype=None,
         device=None,
     ) -> None:
@@ -887,7 +893,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACDR",
             dtype=dtype,
             device=device,
@@ -899,7 +905,7 @@ def __init__(
             norm=True,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             order="NACD",
             if_equal=True,
             dtype=dtype,
@@ -932,16 +938,17 @@ class UnetTower(nn.Module):
         in_channels: Number of channels in the input
         y_in_channels: Number of channels in the higher-resolution representations.
         n_blocks: Number of U-net blocks
+        act_func: Name of the activation function
         kwargs: Additional arguments to be passed to the U-net blocks
     """

     def __init__(
-        self, in_channels: int, y_in_channels: List[int], n_blocks: int, **kwargs
+        self, in_channels: int, y_in_channels: List[int], n_blocks: int, act_func: str = "gelu_borzoi", **kwargs
     ) -> None:
         super().__init__()
         self.blocks = nn.ModuleList()
         for y_c in y_in_channels:
-            self.blocks.append(UnetBlock(in_channels, y_c, **kwargs))
+            self.blocks.append(UnetBlock(in_channels, y_c, act_func=act_func, **kwargs))

     def forward(self, x: Tensor, ys: List[Tensor]) -> Tensor:
         """

src/grelu/model/heads.py
Lines changed: 4 additions & 2 deletions

@@ -27,6 +27,8 @@ class ConvHead(nn.Module):
         norm: If True, batch normalization will be included.
         act_func: Activation function for the convolutional layer
         pool_func: Pooling function.
+        norm: If True, batch normalization will be included.
+        norm_kwargs: Optional dictionary of keyword arguments to pass to the normalization layer
         dtype: Data type for the layers.
         device: Device for the layers.
     """
@@ -38,7 +40,7 @@ def __init__(
         act_func: Optional[str] = None,
         pool_func: Optional[str] = None,
         norm: bool = False,
-        norm_kwargs: Optional[dict] = dict(),
+        norm_kwargs: Optional[dict] = None,
         dtype=None,
         device=None,
     ) -> None:
@@ -56,7 +58,7 @@ def __init__(
             self.n_tasks,
             act_func=self.act_func,
             norm=self.norm,
-            norm_kwargs=norm_kwargs,
+            norm_kwargs=(norm_kwargs or dict()),
             dtype=dtype,
             device=device,
         )

src/grelu/model/models.py
Lines changed: 9 additions & 2 deletions

@@ -518,17 +518,19 @@ def __init__(
         pos_dropout: float = 0.01,
         attn_dropout: float = 0.05,
         ff_dropout: float = 0.2,
-        norm_kwargs: Optional[dict] = {"eps" : 0.001},
+        norm_kwargs: Optional[dict] = None,
         n_heads: int = 8,
         n_pos_features: int = 32,
         # Head
         crop_len: int = 16,
+        act_func: str = "gelu_borzoi",
         final_act_func: Optional[str] = None,
         final_pool_func: Optional[str] = "avg",
         flash_attn=False,
         dtype=None,
         device=None,
     ) -> None:
+        norm_kwargs = norm_kwargs or {"eps": 0.001}
         super().__init__(
             embedding=BorzoiTrunk(
                 stem_channels=stem_channels,
@@ -548,6 +550,7 @@ def __init__(
                 n_pos_features=n_pos_features,
                 crop_len=crop_len,
                 flash_attn=flash_attn,
+                act_func=act_func,
                 dtype=dtype,
                 device=device,
             ),
@@ -577,10 +580,13 @@ def __init__(
         n_transformers: int = 8,
         # head
         crop_len=0,
+        act_func="gelu_borzoi",
+        norm_kwargs: Optional[dict] = None,
         final_pool_func="avg",
         dtype=None,
         device=None,
     ):
+        norm_kwargs = norm_kwargs or {"eps": 0.001}
         model = BorzoiModel(
             crop_len=crop_len,
             n_tasks=7611,
@@ -595,9 +601,10 @@ def __init__(
             pos_dropout=0.01,
             attn_dropout=0.05,
             ff_dropout=0.2,
-            norm_kwargs={"eps": 0.001},
+            norm_kwargs=norm_kwargs,
             n_heads=8,
             n_pos_features=32,
+            act_func=act_func,
             final_act_func=None,
             final_pool_func=None,
             dtype=dtype,
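
In `BorzoiModel`, the old default argument `{"eps" : 0.001}` moves into the constructor body as `norm_kwargs = norm_kwargs or {"eps": 0.001}`, so the dict is rebuilt on every call instead of being shared. A rough sketch of the idea, using a hypothetical `TinyModel` class rather than the grelu API:

from typing import Optional

class TinyModel:
    # Illustrative only: resolve a per-instance norm config at construction time.
    def __init__(self, norm_kwargs: Optional[dict] = None) -> None:
        self.norm_kwargs = norm_kwargs or {"eps": 0.001}

m1, m2 = TinyModel(), TinyModel()
assert m1.norm_kwargs == {"eps": 0.001}
assert m1.norm_kwargs is not m2.norm_kwargs  # no dict shared across instances

# Caveat of `or`: an explicitly passed empty dict is treated as missing.
assert TinyModel(norm_kwargs={}).norm_kwargs == {"eps": 0.001}

An `if norm_kwargs is None` check would preserve a deliberately passed empty dict, at the cost of a slightly longer line.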

src/grelu/model/trunks/borzoi.py
Lines changed: 7 additions & 3 deletions

@@ -36,6 +36,7 @@ def __init__(
         n_blocks: int,
         norm_type="batch",
         norm_kwargs=None,
+        act_func="gelu_borzoi",
         dtype=None,
         device=None,
     ) -> None:
@@ -71,7 +72,7 @@ def __init__(
                 norm=True,
                 norm_type=norm_type,
                 norm_kwargs=norm_kwargs,
-                act_func="gelu_borzoi",
+                act_func=act_func,
                 order="NACDR",
                 pool_func="max",
                 pool_size=2,
@@ -150,6 +151,7 @@ def __init__(
         flash_attn: bool,
         norm_type="batch",
         norm_kwargs=None,
+        act_func="gelu_borzoi",
         dtype=None,
         device=None,
     ) -> None:
@@ -164,6 +166,7 @@ def __init__(
             n_blocks=n_conv,
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
+            act_func=act_func,
             dtype=dtype,
             device=device,
         )
@@ -188,14 +191,15 @@ def __init__(
             y_in_channels=[channels, self.conv_tower.filters[-2]],
             norm_type=norm_type,
             norm_kwargs=norm_kwargs,
+            act_func=act_func,
             dtype=dtype,
             device=device,
         )
         self.pointwise_conv = ConvBlock(
             in_channels=channels,
             out_channels=round(channels * 1.25),
             kernel_size=1,
-            act_func="gelu_borzoi",
+            act_func=act_func,
             dropout=0.1,
             norm=True,
             norm_type=norm_type,
@@ -204,7 +208,7 @@ def __init__(
             device=device,
             dtype=dtype,
         )
-        self.act = Activation("gelu_borzoi")
+        self.act = Activation(act_func)
         self.crop = Crop(crop_len=crop_len)

     def forward(self, x: Tensor) -> Tensor:
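
The borzoi.py edits replace the hardcoded "gelu_borzoi" literal with an `act_func` argument that each constructor forwards to its sub-blocks, so the activation can be overridden in one place. A sketch of the plumbing pattern with placeholder modules (`Stem`, `Trunk`, and the small activation registry are assumptions for illustration, not grelu classes):

import torch
from torch import nn

# Placeholder registry standing in for grelu's Activation factory.
ACTIVATIONS = {"relu": nn.ReLU, "gelu": nn.GELU}

class Stem(nn.Module):
    def __init__(self, channels: int, act_func: str = "gelu") -> None:
        super().__init__()
        self.conv = nn.Conv1d(4, channels, kernel_size=15, padding=7)
        self.act = ACTIVATIONS[act_func]()  # built from the parameter, not a literal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.conv(x))

class Trunk(nn.Module):
    def __init__(self, channels: int, act_func: str = "gelu") -> None:
        super().__init__()
        # The single act_func argument is forwarded to every sub-block,
        # mirroring how the commit passes it into ConvTower, UnetTower, etc.
        self.stem = Stem(channels, act_func=act_func)
        self.act = ACTIVATIONS[act_func]()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.stem(x))

# Overriding the activation in one place changes it everywhere downstream.
trunk = Trunk(channels=32, act_func="relu")
out = trunk(torch.randn(2, 4, 128))
print(out.shape)  # torch.Size([2, 32, 128])

Because the new parameters default to "gelu_borzoi", existing callers keep the previous behaviour while gaining the option to swap activations, as UnetTower does by forwarding `act_func=act_func, **kwargs` to each UnetBlock.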

0 commit comments