
Commit eb43339

add get_implicit_initialization_params, fix mypy partially

1 parent 8fd4e97 commit eb43339

File tree

2 files changed: +28 -27 lines

autointent/modules/scoring/_cnn/cnn.py

Lines changed: 15 additions & 0 deletions
@@ -199,3 +199,18 @@ def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
             optimizer.step()
 
         self._model.eval()
+
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        """Return default params used in initialization."""
+        return {
+            "max_seq_length": self.max_seq_length,
+            "num_train_epochs": self.num_train_epochs,
+            "batch_size": self.batch_size,
+            "learning_rate": self.learning_rate,
+            "seed": self.seed,
+            "report_to": self.report_to,
+            "embed_dim": self.embed_dim,
+            "kernel_sizes": self.kernel_sizes,
+            "num_filters": self.num_filters,
+            "dropout": self.dropout
+        }
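
A minimal usage sketch for the new accessor (hypothetical, not part of the commit): the `scorer` argument stands in for whatever object in cnn.py gains get_implicit_initialization_params, and the JSON-dump helper is an illustrative assumption.

import json
from typing import Any

def snapshot_defaults(scorer: Any) -> str:
    # Serialize the params the module would otherwise fall back to
    # silently, so the effective configuration travels with run logs.
    params = scorer.get_implicit_initialization_params()
    return json.dumps(params, default=str, sort_keys=True)

# e.g. store snapshot_defaults(scorer) alongside a run's metrics.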

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 13 additions & 27 deletions
@@ -13,47 +13,33 @@ def __init__(
         vocab_size: int = 0,
         n_classes: int = 0,
         embed_dim: int = 128,
-        kernel_sizes: list[int] = [3, 4, 5], # noqa: B006
+        kernel_sizes: list[int] = [3, 4, 5],  # noqa: B006
         num_filters: int = 100,
         dropout: float = 0.1,
         padding_idx: int = 0,
         pretrained_embs: torch.Tensor | None = None,
     ) -> None:
-        """Initialize TextCNN model."""
         super().__init__()
 
-        # Register model hyperparameters as buffers
-        self.register_buffer("vocab_size", torch.tensor(vocab_size))
-        self.register_buffer("n_classes", torch.tensor(n_classes))
-        self.register_buffer("embed_dim", torch.tensor(embed_dim))
-        self.register_buffer("kernel_sizes", torch.tensor(kernel_sizes))
-        self.register_buffer("num_filters", torch.tensor(num_filters))
-        self.register_buffer("dropout_rate", torch.tensor(dropout))
-        self.register_buffer("padding_idx", torch.tensor(padding_idx))
+        self.vocab_size = vocab_size
+        self.n_classes = n_classes
+        self.embed_dim = embed_dim
+        self.kernel_sizes = kernel_sizes
+        self.num_filters = num_filters
+        self.dropout_rate = dropout
+        self.padding_idx = padding_idx
 
         if pretrained_embs is not None:
             _, embed_dim = pretrained_embs.shape
-            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)  # type: ignore[no-untyped-call]
-            # Register pretrained embeddings as buffer if they exist
-            self.register_buffer("pretrained_embs", pretrained_embs)
+            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
+            self.pretrained_embs = pretrained_embs
         else:
             self.embedding = nn.Embedding(
                 num_embeddings=vocab_size,
                 embedding_dim=embed_dim,
-                padding_idx=padding_idx
+                padding_idx=padding_idx,
             )
-            # Register None for pretrained_embs buffer
-            self.register_buffer("pretrained_embs", None)
-
-        self.convs = nn.ModuleList([
-            nn.Conv1d(
-                in_channels=embed_dim,
-                out_channels=num_filters,
-                kernel_size=k
-            ) for k in kernel_sizes
-        ])
-        self.dropout = nn.Dropout(dropout)
-        self.fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)
+            self.pretrained_embs = None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass of the model."""

@@ -73,7 +59,7 @@ def load(self, model_path: str) -> None:
         state_dict = torch.load(model_path)
         self.load_state_dict(state_dict)
 
-    def get_config(self) -> dict:
+    def get_config(self) -> dict[str, int | list[int] | torch.Tensor | None]:
         return {
             "vocab_size": self.vocab_size.item(),
             "n_classes": self.n_classes.item(),

0 commit comments