Skip to content

Commit 42ea0cd

Browse files
committed
typing
1 parent 64c575c commit 42ea0cd

File tree

4 files changed

+9
-30
lines changed

4 files changed

+9
-30
lines changed

autointent/modules/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4949
MLKnnScorer,
5050
BertScorer,
5151
BERTLoRAScorer,
52-
RNNScorer
52+
RNNScorer,
5353
]
5454
)
5555

autointent/modules/scoring/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
"RNNScorer",
2020
"RerankScorer",
2121
"SklearnScorer",
22-
]
22+
]

autointent/modules/scoring/_rnn.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy.typing as npt
55
import torch
66
from torch import nn
7+
from torch.optim import Adam
78

89
from autointent import Context
910
from autointent._callbacks import REPORTERS_NAMES
@@ -63,15 +64,6 @@ def get_embedder_config(self) -> dict[str, Any]:
6364
"""Get the configuration of the embedder."""
6465
return self.rnn_config.model_dump()
6566

66-
def _validate_task(self, labels: ListOfLabels) -> None:
67-
"""Validate the task type and set appropriate attributes."""
68-
if isinstance(labels[0], list):
69-
self._multilabel = True
70-
self._n_classes = len(labels[0])
71-
else:
72-
self._multilabel = False
73-
self._n_classes = max(labels) + 1
74-
7567
def __initialize_model(self, vocab_size: int) -> None:
7668
"""Initialize the RNN model."""
7769
self._model = SupervisedRNNClassifier(
@@ -128,7 +120,7 @@ def _texts_to_sequences(self, texts: list[str]) -> torch.Tensor:
128120
def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
129121
"""Train the model."""
130122
self._model.train()
131-
optimizer = torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)
123+
optimizer = Adam(self._model.parameters(), lr=self.learning_rate)
132124

133125
criterion = nn.BCEWithLogitsLoss() if self._multilabel else nn.CrossEntropyLoss()
134126

@@ -201,7 +193,7 @@ def __init__(
201193
super().__init__()
202194
if pretrained_embs is not None:
203195
_, embed_dim = pretrained_embs.shape
204-
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
196+
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # type: ignore[no-untyped-call]
205197
else:
206198
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
207199
self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=n_layers, batch_first=True, dropout=dropout)

tests/modules/scoring/test_rnn.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import numpy as np
32
import pytest
43

@@ -11,11 +10,7 @@ def test_rnn_prediction(dataset):
1110
"""Test that the RNN model can fit and make predictions."""
1211
data_handler = DataHandler(dataset)
1312

14-
scorer = RNNScorer(
15-
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1),
16-
num_train_epochs=1,
17-
batch_size=8
18-
)
13+
scorer = RNNScorer(rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1), num_train_epochs=1, batch_size=8)
1914

2015
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
2116

@@ -52,11 +47,7 @@ def test_rnn_cache_clearing(dataset):
5247
"""Test that the RNN model properly handles cache clearing."""
5348
data_handler = DataHandler(dataset)
5449

55-
scorer = RNNScorer(
56-
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1),
57-
num_train_epochs=1,
58-
batch_size=8
59-
)
50+
scorer = RNNScorer(rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1), num_train_epochs=1, batch_size=8)
6051

6152
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
6253

@@ -82,9 +73,7 @@ def test_rnn_device(dataset):
8273

8374
# Force CPU
8475
scorer_cpu = RNNScorer(
85-
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1, device="cpu"),
86-
num_train_epochs=1,
87-
batch_size=8
76+
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1, device="cpu"), num_train_epochs=1, batch_size=8
8877
)
8978

9079
scorer_cpu.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
@@ -97,9 +86,7 @@ def test_rnn_device(dataset):
9786

9887
# Test with default device
9988
scorer_default = RNNScorer(
100-
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1),
101-
num_train_epochs=1,
102-
batch_size=8
89+
rnn_config=RNNConfig(embed_dim=64, hidden_dim=128, n_layers=1), num_train_epochs=1, batch_size=8
10390
)
10491

10592
scorer_default.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

0 commit comments

Comments
 (0)