Skip to content

Commit b99ecb4

Browse files
committed
fix furr for cnn.py
1 parent 500b58d commit b99ecb4

File tree

1 file changed

+14
-10
lines changed
  • autointent/modules/scoring/_cnn

1 file changed

+14
-10
lines changed

autointent/modules/scoring/_cnn/cnn.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
"""CNNScorer class for scoring."""
22

3-
from collections import Counter
43
import re
4+
from collections import Counter
55
from typing import Any
66

77
import numpy as np
88
import numpy.typing as npt
9-
from torch import nn
109
import torch
11-
from torch.utils.data import TensorDataset, DataLoader
10+
from torch import nn, Tensor
11+
from torch.utils.data import DataLoader, TensorDataset
1212

1313
from autointent import Context
1414
from autointent._callbacks import REPORTERS_NAMES
@@ -61,7 +61,7 @@ def from_context(
6161
learning_rate: float = 5e-5,
6262
seed: int = 0,
6363
**cnn_kwargs: dict[str, Any],
64-
) -> CNNScorer:
64+
) -> "CNNScorer":
6565
return cls(
6666
num_train_epochs=num_train_epochs,
6767
batch_size=batch_size,
@@ -88,7 +88,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
8888

8989
# Initialize model
9090
if self._vocab is None:
91-
raise ValueError("Vocabulary not built")
91+
msg = "Vocabulary not built"
92+
raise ValueError(msg)
9293

9394
self._model = TextCNN(
9495
vocab_size=len(self._vocab),
@@ -106,7 +107,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
106107

107108
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
108109
if self._model is None:
109-
raise ValueError("Model not trained. Call fit() first.")
110+
msg = "Model not trained. Call fit() first."
111+
raise ValueError(msg)
110112

111113
x = self._text_to_indices(utterances)
112114
x_tensor = torch.tensor(x, dtype=torch.long)
@@ -138,7 +140,8 @@ def _build_vocab(self, utterances: list[str]) -> None:
138140

139141
# Add words to vocabulary
140142
if self._vocab is None:
141-
raise ValueError("Vocabulary not initialized")
143+
msg = "Vocabulary not initialized"
144+
raise ValueError(msg)
142145

143146
for word, _ in word_counts.most_common():
144147
if word not in self._vocab:
@@ -150,7 +153,8 @@ def _build_vocab(self, utterances: list[str]) -> None:
150153
def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
151154
"""Convert utterances to padded sequences of word indices."""
152155
if self._vocab is None:
153-
raise ValueError("Vocabulary not built")
156+
msg = "Vocabulary not built"
157+
raise ValueError(msg)
154158

155159
sequences: list[list[int]] = []
156160
for utterance in utterances:
@@ -170,7 +174,8 @@ def clear_cache(self) -> None:
170174

171175
def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
172176
if self._model is None:
173-
raise ValueError("Model not initialized")
177+
msg = "Model not initialized"
178+
raise ValueError(msg)
174179

175180
dataset = TensorDataset(x, y)
176181
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
@@ -190,4 +195,3 @@ def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
190195
optimizer.step()
191196

192197
self._model.eval()
193-

0 commit comments

Comments
 (0)