Commit 500b58d

fix ruff

1 parent abb4e34 commit 500b58d

2 files changed: +61 -67 lines changed
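
Most of this commit is mechanical ruff cleanup: typing.Dict/List/Optional/Tuple annotations become builtin generics and | unions (pyupgrade-style rules such as UP006/UP007), RuntimeError guards become ValueError, imports are reordered, and the formatter rewraps long lines and slice spacing. The annotation pattern applied throughout, shown on a hypothetical signature for illustration (score and its parameters are not from this repo):

# before: typing aliases
def score(texts: List[str], weights: Optional[Dict[str, float]] = None) -> Tuple[int, ...]: ...

# after: PEP 585 builtin generics, PEP 604 unions
def score(texts: list[str], weights: dict[str, float] | None = None) -> tuple[int, ...]: ...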
Lines changed: 55 additions & 59 deletions
@@ -1,16 +1,14 @@
 """CNNScorer class for scoring."""
 
-from __future__ import annotations
-
-import re
 from collections import Counter
-from typing import Any, Dict, List, Optional, Union
+import re
+from typing import Any
 
 import numpy as np
 import numpy.typing as npt
+from torch import nn
 import torch
-from torch import nn, Tensor
-from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data import TensorDataset, DataLoader
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -33,8 +31,8 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None, # type: ignore[no-any-return]
-        **cnn_kwargs: Dict[str, Any],
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
+        **cnn_kwargs: dict[str, Any],
     ) -> None:
         self.max_seq_length = max_seq_length
         self.num_train_epochs = num_train_epochs
@@ -43,10 +41,10 @@ def __init__(
         self.seed = seed
         self.report_to = report_to
         self.cnn_config = cnn_kwargs
-
+
         # Will be initialized during fit()
-        self._model: Optional[TextCNN] = None
-        self._vocab: Optional[Dict[str, int]] = None
+        self._model: TextCNN | None = None
+        self._vocab: dict[str, int] | None = None
         self._padding_idx = 0
         self._unk_token = "<UNK>"  # noqa: S105
         self._pad_token = "<PAD>"  # noqa: S105
@@ -62,8 +60,8 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **cnn_kwargs: Dict[str, Any],
-    ) -> "CNNScorer":
+        **cnn_kwargs: dict[str, Any],
+    ) -> CNNScorer:
         return cls(
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
@@ -73,23 +71,25 @@ def from_context(
             **cnn_kwargs,
         )
 
-    def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
+    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         self._validate_task(labels)
-        self._multilabel = isinstance(labels[0], (list, np.ndarray))
+        self._multilabel = isinstance(labels[0], list | np.ndarray)
         self._n_classes = len(labels[0]) if self._multilabel else len(set(labels))
-
+
         # Build vocabulary and tokenize
         self._build_vocab(utterances)
-
+
         # Convert text to padded indices
         x = self._text_to_indices(utterances)
         x_tensor = torch.tensor(x, dtype=torch.long)
-        y_tensor = torch.tensor(labels, dtype=torch.long if not self._multilabel else torch.float)
-
+        y_tensor = torch.tensor(
+            labels, dtype=torch.long if not self._multilabel else torch.float
+        )
+
         # Initialize model
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not built")
-
+            raise ValueError("Vocabulary not built")
+
         self._model = TextCNN(
             vocab_size=len(self._vocab),
             n_classes=self._n_classes,
@@ -98,70 +98,67 @@ def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
             num_filters=self.cnn_config.get("num_filters", 100),
             dropout=self.cnn_config.get("dropout", 0.1),
             padding_idx=self._padding_idx,
-            pretrained_embs=self.cnn_config.get("pretrained_embs", None)
+            pretrained_embs=self.cnn_config.get("pretrained_embs", None),
         )
-
+
         # Training
         self._train_model(x_tensor, y_tensor)
 
-    def predict(self, utterances: List[str]) -> npt.NDArray[Any]:
+    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if self._model is None:
-            raise RuntimeError("Model not trained. Call fit() first.")
-
+            raise ValueError("Model not trained. Call fit() first.")
+
         x = self._text_to_indices(utterances)
         x_tensor = torch.tensor(x, dtype=torch.long)
-
+
         self._model.eval()
-        all_probs: List[npt.NDArray[Any]] = []
-
+        all_probs: list[npt.NDArray[Any]] = []
+
         with torch.no_grad():
             for i in range(0, len(x_tensor), self.batch_size):
-                batch_x = x_tensor[i:i+self.batch_size]
+                batch_x = x_tensor[i : i + self.batch_size]
                 outputs = self._model(batch_x)
                 if self._multilabel:
                     probs = torch.sigmoid(outputs).cpu().numpy()
                 else:
                     probs = torch.softmax(outputs, dim=1).cpu().numpy()
                 all_probs.append(probs)
-
+
         return np.concatenate(all_probs, axis=0) if all_probs else np.array([])
 
-    def _build_vocab(self, utterances: List[str]) -> None:
+    def _build_vocab(self, utterances: list[str]) -> None:
         """Build vocabulary from training utterances."""
-        word_counts: Dict[str, int] = Counter()
+        word_counts: dict[str, int] = Counter()
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             word_counts.update(words)
-
+
         # Create vocabulary with special tokens
-        self._vocab = {
-            self._pad_token: 0,
-            self._unk_token: 1
-        }
-
+        self._vocab = {self._pad_token: 0, self._unk_token: 1}
+
         # Add words to vocabulary
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not initialized")
-
+            raise ValueError("Vocabulary not initialized")
+
         for word, _ in word_counts.most_common():
             if word not in self._vocab:
                 self._vocab[word] = len(self._vocab)
-
+
         self._unk_idx = 1
         self._padding_idx = 0
 
-    def _text_to_indices(self, utterances: List[str]) -> List[List[int]]:
+    def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
         """Convert utterances to padded sequences of word indices."""
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not built")
-
-        sequences: List[List[int]] = []
+            raise ValueError("Vocabulary not built")
+
+        sequences: list[list[int]] = []
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             # Convert words to indices, using UNK for unknown words
-            seq = [self._vocab.get(word, self._unk_idx) for word in words] # type: ignore
+            seq = [self._vocab.get(word, self._unk_idx) for word in words]  # type: ignore[union-attr]
            # Truncate if too long
-            seq = seq[:self.max_seq_length]
+            seq = seq[: self.max_seq_length]
             # Pad if too short
             seq = seq + [self._padding_idx] * (self.max_seq_length - len(seq))
             sequences.append(seq)
@@ -171,20 +168,18 @@ def clear_cache(self) -> None:
         self._model = None
         torch.cuda.empty_cache()
 
-    def _train_model(self, x: Tensor, y: Tensor) -> None:
+    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
         if self._model is None:
-            raise RuntimeError("Model not initialized")
-
+            raise ValueError("Model not initialized")
+
         dataset = TensorDataset(x, y)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=True
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+
+        criterion = (
+            nn.CrossEntropyLoss() if not self._multilabel else nn.BCEWithLogitsLoss()
         )
-
-        criterion = nn.CrossEntropyLoss() if not self._multilabel else nn.BCEWithLogitsLoss()
         optimizer = torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)
-
+
         self._model.train()
         for _ in range(self.num_train_epochs):
             for batch_x, batch_y in dataloader:
@@ -193,5 +188,6 @@ def _train_model(self, x: Tensor, y: Tensor) -> None:
                 loss = criterion(outputs, batch_y)
                 loss.backward()
                 optimizer.step()
-
+
         self._model.eval()
+
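For orientation, a minimal usage sketch of the scorer this first file defines. The constructor parameters and the fit()/predict() contract are taken from the diff above; the example utterances, labels, and hyperparameter values are hypothetical:

# Hypothetical usage sketch, not part of the commit; assumes the CNNScorer
# class from the diff above, in single-label mode (one integer class per utterance).
utterances = ["book a flight", "cancel my order", "play some music"]
labels = [0, 1, 2]

scorer = CNNScorer(max_seq_length=32, num_train_epochs=2, batch_size=2, seed=0)
scorer.fit(utterances, labels)  # builds the vocab, then trains a TextCNN
probs = scorer.predict(["reserve a plane ticket"])  # shape (1, n_classes), softmax probabilities
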
autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 6 additions & 8 deletions
@@ -1,10 +1,8 @@
 """TextCNN model for text classification."""
 
-from typing import Optional, Tuple
-
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 
 class TextCNN(nn.Module):
@@ -15,15 +13,15 @@ def __init__(
         vocab_size: int,
         n_classes: int,
         embed_dim: int = 128,
-        kernel_sizes: Tuple[int, ...] = (3, 4, 5),
+        kernel_sizes: tuple[int, ...] = (3, 4, 5),
         num_filters: int = 100,
         dropout: float = 0.1,
         padding_idx: int = 0,
-        pretrained_embs: Optional[torch.Tensor] = None,
+        pretrained_embs: torch.Tensor | None = None,
     ) -> None:
         """Initialize TextCNN model."""
         super().__init__()
-
+
         if pretrained_embs is not None:
             _, embed_dim = pretrained_embs.shape
             self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
@@ -33,7 +31,7 @@ def __init__(
                 embedding_dim=embed_dim,
                 padding_idx=padding_idx
             )
-
+
         self.convs = nn.ModuleList([
             nn.Conv1d(
                 in_channels=embed_dim,
@@ -51,4 +49,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = [F.relu(conv(x)).max(dim=2)[0] for conv in self.convs]
         x = torch.cat(x, dim=1)
         x = self.dropout(x)
-        return self.fc(x)
\ No newline at end of file
+        return self.fc(x)
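
And a standalone shape sketch for the TextCNN model itself. The constructor signature matches the diff; the vocab size, batch shape, and the assumption that forward() feeds embeddings to the Conv1d stack channels-first (implied by in_channels=embed_dim) are mine:

# Hypothetical shape check, not part of the commit.
import torch

model = TextCNN(vocab_size=1000, n_classes=5)  # defaults: embed_dim=128, kernel_sizes=(3, 4, 5)
x = torch.randint(0, 1000, (8, 50))  # (batch, seq_len) of word indices; 0 is the padding index
logits = model(x)  # per kernel size: ReLU conv, max-pool over time; then concat, dropout, fc
print(logits.shape)  # expected: torch.Size([8, 5]), one logit per class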
