Skip to content

Commit 03f51b4

Browse files
committed
Update cnn.py
1 parent 9c1b6d0 commit 03f51b4

File tree

1 file changed

+2
-12
lines changed
  • autointent/modules/scoring/_cnn

1 file changed

+2
-12
lines changed

autointent/modules/scoring/_cnn/cnn.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
kernel_sizes: list[int] = [3, 4, 5], # noqa: B006
3636
num_filters: int = 100,
3737
dropout: float = 0.1,
38+
batch_size: int = 8,
3839
cnn_config: CNNConfig | str | dict[str, Any] | None = None,
3940
) -> None:
4041
self.num_train_epochs = num_train_epochs
@@ -56,7 +57,7 @@ def __init__(
5657
self._multilabel: bool = False
5758
self._pad_idx = self.cnn_config.padding_idx
5859
self._unk_idx = self.cnn_config.unknown_idx
59-
self.batch_size = self.cnn_config.batch_size
60+
self.batch_size = batch_size
6061
self.max_seq_length = self.cnn_config.max_seq_length
6162

6263
@classmethod
@@ -86,17 +87,6 @@ def from_context(
8687
cnn_config=cnn_config
8788
)
8889

89-
def get_embedder_config(self) -> dict[str, Any]:
90-
"""Get the configuration of the embedder."""
91-
config = self.cnn_config.model_dump()
92-
config.update({
93-
"embed_dim": self.embed_dim,
94-
"hidden_dim": self.hidden_dim,
95-
"n_layers": self.n_layers,
96-
"dropout": self.dropout,
97-
})
98-
return config
99-
10090
def get_implicit_initialization_params(self) -> dict[str, Any]:
10191
return {"cnn_config": self.cnn_config.model_dump()}
10292

0 commit comments

Comments
 (0)