Skip to content

Commit 320f341

Browse files
committed
mypy1
1 parent f7b19ad commit 320f341

File tree

5 files changed

+19
-6
lines changed

5 files changed

+19
-6
lines changed

autointent/_wrappers/base_torch_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def __init__(
7676

7777
def set_vocab(self, vocab: dict[str, Any]) -> None:
7878
"""Save vocabulary into module's attributes and initialize embeddings matrix."""
79+
if self.embed_dim is None:
80+
msg = "embed_dim must be set to initialize embeddings"
81+
raise ValueError(msg)
7982
self.vocab_config.vocab = vocab
8083
self.embedding = nn.Embedding(
8184
num_embeddings=len(self.vocab_config.vocab),

autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def create_correlation_matrix(
8080
adj_matrix = (conditional_prob > tau).float()
8181

8282
adj_matrix_no_self_loop = adj_matrix - torch.eye(num_classes, device=adj_matrix.device)
83-
sum_neighbors = adj_matrix_no_self_loop.sum(axis=1)
83+
sum_neighbors = adj_matrix_no_self_loop.sum(dim=1)
8484

8585
weights_p = p / sum_neighbors
8686
weights_p.nan_to_num_(0)

autointent/modules/scoring/_gcn/gcn_scorer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,16 @@ def __init__( # noqa: PLR0913
4545
self.gcn_hidden_dims = gcn_hidden_dims
4646
self.p_reweight = p_reweight
4747
self.tau_threshold = tau_threshold
48-
self.torch_config = TorchTrainingConfig(
48+
torch_config = TorchTrainingConfig(
4949
num_train_epochs=num_train_epochs,
5050
batch_size=batch_size,
5151
learning_rate=learning_rate,
5252
seed=seed,
5353
)
5454
if device is not None:
55-
self.torch_config.device = device
56-
self.early_stopping_config = EarlyStoppingConfig.from_search_config(early_stopping_config)
55+
torch_config.device = device
56+
57+
super().__init__(torch_config=torch_config, early_stopping_config=early_stopping_config)
5758

5859
@classmethod
5960
def from_context( # noqa: PLR0913

autointent/modules/scoring/_torch/base_scorer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ class BaseTorchTrainerScorer(BaseScorer):
2727
_best_model_weights = "best_model.pt"
2828
_model: BaseTorchModule
2929

30+
def __init__(
31+
self,
32+
torch_config: TorchTrainingConfig | dict[str, Any] | None = None,
33+
early_stopping_config: EarlyStoppingConfig | dict[str, Any] | None = None,
34+
) -> None:
35+
self.torch_config = TorchTrainingConfig.from_search_config(torch_config)
36+
self.early_stopping_config = EarlyStoppingConfig.from_search_config(early_stopping_config)
37+
3038
def _train_model(
3139
self,
3240
train_x: torch.Tensor,
@@ -170,9 +178,8 @@ def __init__(
170178
vocab_config: VocabConfig | dict[str, Any] | None = None,
171179
early_stopping_config: EarlyStoppingConfig | dict[str, Any] | None = None,
172180
) -> None:
173-
self.torch_config = TorchTrainingConfig.from_search_config(torch_config)
181+
super().__init__(torch_config=torch_config, early_stopping_config=early_stopping_config)
174182
self.vocab_config = VocabConfig.from_search_config(vocab_config)
175-
self.early_stopping_config = EarlyStoppingConfig.from_search_config(early_stopping_config)
176183

177184
@abstractmethod
178185
def _init_model(self) -> BaseTorchModuleWithVocab: ...

autointent/modules/scoring/_torch/cnn_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
self.kernel_sizes = kernel_sizes
4747
self.num_filters = num_filters
4848
self.dropout_rate = dropout
49+
assert self.embed_dim is not None # noqa: S101
4950

5051
# Initialize other layers
5152
self.convs = nn.ModuleList(
@@ -70,6 +71,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7071
return self.fc(dropped) # type: ignore[no-any-return]
7172

7273
def dump(self, path: Path) -> None:
74+
assert self.embed_dim is not None # noqa: S101
7375
metadata = TextCNNDumpMetadata(
7476
embed_dim=self.embed_dim,
7577
n_classes=self.n_classes,

0 commit comments

Comments
 (0)