Commit b7be78d

fix ruff

1 parent 43d7e72 · commit b7be78d

File tree

  • autointent/modules/scoring/_cnn

1 file changed: +31 −29 lines changed
autointent/modules/scoring/_cnn/cnn.py

Lines changed: 31 additions & 29 deletions
```diff
@@ -1,16 +1,17 @@
 """CNNScorer class for scoring."""
 
+from collections import Counter
+import re
 from typing import Any
+
 import numpy as np
 import numpy.typing as npt
+from torch import nn
 import torch
-import torch.nn as nn
-from collections import Counter
-import re
+from torch.utils.data import TensorDataset, DataLoader
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
-from autointent.configs import EmbedderConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
 from autointent.modules.scoring._cnn.textcnn import TextCNN
```
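The import shuffle is standard isort-style cleanup (likely ruff's I001 family): standard-library, third-party, and first-party imports are grouped and sorted, `import torch.nn as nn` becomes the direct `from torch import nn`, and the presumably unused `EmbedderConfig` import is dropped (F401). A minimal sketch of the grouping convention, using only names that appear in the diff:

```python
# Standard-library imports come first...
import re
from collections import Counter
from typing import Any

# ...then third-party packages...
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# ...then first-party modules.
from autointent.modules.base import BaseScorer
```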
```diff
@@ -47,8 +48,8 @@ def __init__(
         self._model = None
         self._vocab = None
         self._padding_idx = 0
-        self._unk_token = "<UNK>"
-        self._pad_token = "<PAD>"
+        self._unk_token = "<UNK>"  # noqa: S105
+        self._pad_token = "<PAD>"  # noqa: S105
 
     @classmethod
     def from_context(
```
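The new `# noqa: S105` comments answer flake8-bandit's hardcoded-password check: S105 flags string literals assigned to names containing words like `token`, `secret`, or `password`, so vocabulary sentinels named `_unk_token` and `_pad_token` trip it even though nothing secret is involved. A minimal reproduction (the variable names here are illustrative):

```python
# S105 keys on the *name*, not the value: a string literal assigned to
# anything that looks like a credential name gets reported.
session_token = "<UNK>"            # would be flagged by S105
pad_token = "<PAD>"  # noqa: S105  # suppressed: a padding sentinel, not a secret
```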
```diff
@@ -74,45 +75,46 @@ def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = F
             self.clear_cache()
 
         self._validate_task(labels)
-        self._multilabel = isinstance(labels[0], (list, np.ndarray))
+        self._multilabel = isinstance(labels[0], list | np.ndarray)
 
         # Build vocabulary and tokenize
         self._build_vocab(utterances)
 
         # Convert text to padded indices
-        X = self._text_to_indices(utterances)
-        X = torch.tensor(X, dtype=torch.long)
+        x = self._text_to_indices(utterances)
+        x = torch.tensor(x, dtype=torch.long)
         y = torch.tensor(labels, dtype=torch.long)
 
         # Initialize model
         self._model = TextCNN(
             vocab_size=len(self._vocab),
             n_classes=self._n_classes,
-            embed_dim=self.cnn_config.get('embed_dim', 128),
-            kernel_sizes=self.cnn_config.get('kernel_sizes', (3, 4, 5)),
-            num_filters=self.cnn_config.get('num_filters', 100),
-            dropout=self.cnn_config.get('dropout', 0.1),
+            embed_dim=self.cnn_config.get("embed_dim", 128),
+            kernel_sizes=self.cnn_config.get("kernel_sizes", (3, 4, 5)),
+            num_filters=self.cnn_config.get("num_filters", 100),
+            dropout=self.cnn_config.get("dropout", 0.1),
             padding_idx=self._padding_idx,
-            pretrained_embs=self.cnn_config.get('pretrained_embs', None)
+            pretrained_embs=self.cnn_config.get("pretrained_embs", None)
         )
 
         # Training
-        self._train_model(X, y)
+        self._train_model(x, y)
 
     def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if self._model is None:
-            raise RuntimeError("Model not trained. Call fit() first.")
+            error_msg = "Model not trained. Call fit() first."
+            raise RuntimeError(error_msg)
 
-        X = self._text_to_indices(utterances)
-        X = torch.tensor(X, dtype=torch.long)
+        x = self._text_to_indices(utterances)
+        x = torch.tensor(x, dtype=torch.long)
 
         self._model.eval()
         all_probs = []
 
         with torch.no_grad():
-            for i in range(0, len(X), self.batch_size):
-                batch_X = X[i:i+self.batch_size]
-                outputs = self._model(batch_X)
+            for i in range(0, len(x), self.batch_size):
+                batch_x = x[i:i+self.batch_size]
+                outputs = self._model(batch_x)
                 if self._multilabel:
                     probs = torch.sigmoid(outputs).cpu().numpy()
                 else:
```
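Three mechanical rules drive this hunk. The `isinstance` rewrite is likely pyupgrade's UP038, which prefers PEP 604 union syntax on Python 3.10+; the exception change matches flake8-errmsg's EM101, which wants the message bound to a variable so the raw string doesn't reappear in the traceback's `raise` line; and the `X` → `x` renames satisfy pep8-naming's lowercase-local rule (N806). A short sketch of the first two patterns:

```python
import numpy as np

labels = [[0, 1], [1, 0]]

# UP038: isinstance accepts a PEP 604 union directly on Python 3.10+.
multilabel = isinstance(labels[0], list | np.ndarray)
assert multilabel == isinstance(labels[0], (list, np.ndarray))  # equivalent tuple form

# EM101: bind the message first, then raise it.
def require_fitted(model: object | None) -> None:
    if model is None:
        error_msg = "Model not trained. Call fit() first."
        raise RuntimeError(error_msg)
```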
```diff
@@ -125,7 +127,7 @@ def _build_vocab(self, utterances: list[str]) -> None:
         """Build vocabulary from training utterances."""
         word_counts = Counter()
         for utterance in utterances:
-            words = re.findall(r'\w+', utterance.lower())
+            words = re.findall(r"\w+", utterance.lower())
             word_counts.update(words)
 
         # Create vocabulary with special tokens
@@ -146,7 +148,7 @@ def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
         """Convert utterances to padded sequences of word indices."""
         sequences = []
         for utterance in utterances:
-            words = re.findall(r'\w+', utterance.lower())
+            words = re.findall(r"\w+", utterance.lower())
             # Convert words to indices, using UNK for unknown words
             seq = [self._vocab.get(word, self._unk_idx) for word in words]
             # Truncate if too long
```
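The quote flips here and in the `cnn_config.get` calls above are pure style: ruff standardizes on double quotes (its formatter default, or the flake8-quotes rules when linting), and a raw string compiles to the same pattern either way:

```python
import re

# Only the source style changes; the compiled pattern is identical.
assert r"\w+" == r'\w+'
assert re.findall(r"\w+", "Hello, world!") == ["Hello", "world"]
```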
```diff
@@ -160,9 +162,9 @@ def clear_cache(self) -> None:
         self._model = None
         torch.cuda.empty_cache()
 
-    def _train_model(self, X: torch.Tensor, y: torch.Tensor) -> None:
-        dataset = torch.utils.data.TensorDataset(X, y)
-        dataloader = torch.utils.data.DataLoader(
+    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        dataset = TensorDataset(x, y)
+        dataloader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=True
```
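Importing `TensorDataset` and `DataLoader` at module top level replaces the dotted `torch.utils.data` access in `_train_model`; both spellings resolve to the same classes, so behavior is unchanged:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# The direct imports are the very objects reachable through the dotted path.
assert TensorDataset is torch.utils.data.TensorDataset
assert DataLoader is torch.utils.data.DataLoader

x = torch.zeros(4, 3, dtype=torch.long)
y = torch.tensor([0, 1, 0, 1])
loader = DataLoader(TensorDataset(x, y), batch_size=2, shuffle=True)
```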
```diff
@@ -172,10 +174,10 @@ def _train_model(self, X: torch.Tensor, y: torch.Tensor) -> None:
         optimizer = torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)
 
         self._model.train()
-        for epoch in range(self.num_train_epochs):
-            for batch_X, batch_y in dataloader:
+        for _ in range(self.num_train_epochs):
+            for batch_x, batch_y in dataloader:
                 optimizer.zero_grad()
-                outputs = self._model(batch_X)
+                outputs = self._model(batch_x)
                 loss = criterion(outputs, batch_y)
                 loss.backward()
                 optimizer.step()
```
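Finally, the unused `epoch` counter becomes `_` (likely flake8-bugbear's B007) and `batch_X` is lowercased like the other locals. For context, a self-contained sketch of the same training-loop shape; the tiny linear model and synthetic data are stand-ins, since `TextCNN` and the scorer's hyperparameters aren't defined here:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(3, 2)  # stand-in for TextCNN
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 3)
y = torch.randint(0, 2, (8,))
dataloader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)

model.train()
for _ in range(2):  # unused epoch counter named `_`
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
```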
