Skip to content

Commit abb4e34

Browse files
committed
fix ruff and mypy
1 parent 6c98caa commit abb4e34

File tree

2 files changed

+84
-55
lines changed

2 files changed

+84
-55
lines changed

autointent/modules/scoring/_cnn/cnn.py

Lines changed: 45 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,16 @@
11
"""CNNScorer class for scoring."""
22

3-
from collections import Counter
3+
from __future__ import annotations
4+
45
import re
5-
from typing import Any
6+
from collections import Counter
7+
from typing import Any, Dict, List, Optional, Union
68

79
import numpy as np
810
import numpy.typing as npt
9-
from torch import nn
1011
import torch
11-
from torch.utils.data import TensorDataset, DataLoader
12+
from torch import nn, Tensor
13+
from torch.utils.data import DataLoader, TensorDataset
1214

1315
from autointent import Context
1416
from autointent._callbacks import REPORTERS_NAMES
@@ -21,8 +23,6 @@ class CNNScorer(BaseScorer):
2123
"""Convolutional Neural Network (CNN) scorer for intent classification."""
2224

2325
name = "cnn"
24-
_n_classes: int
25-
_multilabel: bool
2626
supports_multilabel = True
2727
supports_multiclass = True
2828

@@ -33,8 +33,8 @@ def __init__(
3333
batch_size: int = 8,
3434
learning_rate: float = 5e-5,
3535
seed: int = 0,
36-
report_to: REPORTERS_NAMES | None = None,
37-
**cnn_kwargs: dict[str, Any],
36+
report_to: REPORTERS_NAMES | None = None, # type: ignore[no-any-return]
37+
**cnn_kwargs: Dict[str, Any],
3838
) -> None:
3939
self.max_seq_length = max_seq_length
4040
self.num_train_epochs = num_train_epochs
@@ -45,11 +45,14 @@ def __init__(
4545
self.cnn_config = cnn_kwargs
4646

4747
# Will be initialized during fit()
48-
self._model = None
49-
self._vocab = None
48+
self._model: Optional[TextCNN] = None
49+
self._vocab: Optional[Dict[str, int]] = None
5050
self._padding_idx = 0
5151
self._unk_token = "<UNK>" # noqa: S105
5252
self._pad_token = "<PAD>" # noqa: S105
53+
self._unk_idx = 1
54+
self._n_classes: int = 0
55+
self._multilabel: bool = False
5356

5457
@classmethod
5558
def from_context(
@@ -59,7 +62,7 @@ def from_context(
5962
batch_size: int = 8,
6063
learning_rate: float = 5e-5,
6164
seed: int = 0,
62-
**cnn_kwargs: dict[str, Any],
65+
**cnn_kwargs: Dict[str, Any],
6366
) -> "CNNScorer":
6467
return cls(
6568
num_train_epochs=num_train_epochs,
@@ -70,22 +73,23 @@ def from_context(
7073
**cnn_kwargs,
7174
)
7275

73-
def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
74-
if clear_cache:
75-
self.clear_cache()
76-
76+
def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
7777
self._validate_task(labels)
78-
self._multilabel = isinstance(labels[0], list | np.ndarray)
78+
self._multilabel = isinstance(labels[0], (list, np.ndarray))
79+
self._n_classes = len(labels[0]) if self._multilabel else len(set(labels))
7980

8081
# Build vocabulary and tokenize
8182
self._build_vocab(utterances)
8283

8384
# Convert text to padded indices
8485
x = self._text_to_indices(utterances)
85-
x = torch.tensor(x, dtype=torch.long)
86-
y = torch.tensor(labels, dtype=torch.long)
86+
x_tensor = torch.tensor(x, dtype=torch.long)
87+
y_tensor = torch.tensor(labels, dtype=torch.long if not self._multilabel else torch.float)
8788

8889
# Initialize model
90+
if self._vocab is None:
91+
raise RuntimeError("Vocabulary not built")
92+
8993
self._model = TextCNN(
9094
vocab_size=len(self._vocab),
9195
n_classes=self._n_classes,
@@ -98,22 +102,21 @@ def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = F
98102
)
99103

100104
# Training
101-
self._train_model(x, y)
105+
self._train_model(x_tensor, y_tensor)
102106

103-
def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
107+
def predict(self, utterances: List[str]) -> npt.NDArray[Any]:
104108
if self._model is None:
105-
error_msg = "Model not trained. Call fit() first."
106-
raise RuntimeError(error_msg)
109+
raise RuntimeError("Model not trained. Call fit() first.")
107110

108111
x = self._text_to_indices(utterances)
109-
x = torch.tensor(x, dtype=torch.long)
112+
x_tensor = torch.tensor(x, dtype=torch.long)
110113

111114
self._model.eval()
112-
all_probs = []
115+
all_probs: List[npt.NDArray[Any]] = []
113116

114117
with torch.no_grad():
115-
for i in range(0, len(x), self.batch_size):
116-
batch_x = x[i:i+self.batch_size]
118+
for i in range(0, len(x_tensor), self.batch_size):
119+
batch_x = x_tensor[i:i+self.batch_size]
117120
outputs = self._model(batch_x)
118121
if self._multilabel:
119122
probs = torch.sigmoid(outputs).cpu().numpy()
@@ -123,9 +126,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
123126

124127
return np.concatenate(all_probs, axis=0) if all_probs else np.array([])
125128

126-
def _build_vocab(self, utterances: list[str]) -> None:
129+
def _build_vocab(self, utterances: List[str]) -> None:
127130
"""Build vocabulary from training utterances."""
128-
word_counts = Counter()
131+
word_counts: Dict[str, int] = Counter()
129132
for utterance in utterances:
130133
words = re.findall(r"\w+", utterance.lower())
131134
word_counts.update(words)
@@ -137,20 +140,26 @@ def _build_vocab(self, utterances: list[str]) -> None:
137140
}
138141

139142
# Add words to vocabulary
143+
if self._vocab is None:
144+
raise RuntimeError("Vocabulary not initialized")
145+
140146
for word, _ in word_counts.most_common():
141147
if word not in self._vocab:
142148
self._vocab[word] = len(self._vocab)
143149

144150
self._unk_idx = 1
145151
self._padding_idx = 0
146152

147-
def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
153+
def _text_to_indices(self, utterances: List[str]) -> List[List[int]]:
148154
"""Convert utterances to padded sequences of word indices."""
149-
sequences = []
155+
if self._vocab is None:
156+
raise RuntimeError("Vocabulary not built")
157+
158+
sequences: List[List[int]] = []
150159
for utterance in utterances:
151160
words = re.findall(r"\w+", utterance.lower())
152161
# Convert words to indices, using UNK for unknown words
153-
seq = [self._vocab.get(word, self._unk_idx) for word in words]
162+
seq = [self._vocab.get(word, self._unk_idx) for word in words] # type: ignore
154163
# Truncate if too long
155164
seq = seq[:self.max_seq_length]
156165
# Pad if too short
@@ -162,7 +171,10 @@ def clear_cache(self) -> None:
162171
self._model = None
163172
torch.cuda.empty_cache()
164173

165-
def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
174+
def _train_model(self, x: Tensor, y: Tensor) -> None:
175+
if self._model is None:
176+
raise RuntimeError("Model not initialized")
177+
166178
dataset = TensorDataset(x, y)
167179
dataloader = DataLoader(
168180
dataset,
@@ -182,4 +194,4 @@ def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
182194
loss.backward()
183195
optimizer.step()
184196

185-
self._model.eval()
197+
self._model.eval()
Lines changed: 39 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,37 +1,54 @@
1+
"""TextCNN model for text classification."""
2+
3+
from typing import Optional, Tuple
4+
15
import torch
2-
import torch.nn as nn
6+
from torch import nn
37
import torch.nn.functional as F
48

9+
510
class TextCNN(nn.Module):
    """Convolutional text classifier over sequences of token indices.

    The network embeds each token, runs one 1-D convolution per entry of
    ``kernel_sizes`` over the sequence axis, max-pools every convolution
    output over time, concatenates the pooled features, applies dropout and
    projects them to ``n_classes`` raw scores (no softmax/sigmoid is applied
    here).
    """

    def __init__(
        self,
        vocab_size: int,
        n_classes: int,
        embed_dim: int = 128,
        kernel_sizes: Tuple[int, ...] = (3, 4, 5),
        num_filters: int = 100,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pretrained_embs: Optional[torch.Tensor] = None,
    ) -> None:
        """Build the embedding, convolution, dropout and output layers.

        Args:
            vocab_size: Number of rows of the trainable embedding table.
                Ignored when ``pretrained_embs`` is given.
            n_classes: Size of the output layer.
            embed_dim: Embedding width. Overridden by the second dimension of
                ``pretrained_embs`` when that argument is given.
            kernel_sizes: Width of each convolution; one ``Conv1d`` is created
                per entry. Must not be empty.
            num_filters: Output channels of every convolution.
            dropout: Dropout probability applied to the pooled features.
            padding_idx: Index of the padding token in the trainable
                embedding table. Not used when ``pretrained_embs`` is given.
            pretrained_embs: Optional 2-D tensor of embedding weights; when
                given it is loaded frozen instead of a trainable table.

        Raises:
            ValueError: If ``kernel_sizes`` is empty.
        """
        super().__init__()

        # Without at least one convolution there is nothing to concatenate in
        # forward(); fail at construction time rather than on the first batch.
        if not kernel_sizes:
            msg = "kernel_sizes must contain at least one kernel size"
            raise ValueError(msg)

        if pretrained_embs is not None:
            # The embedding width is dictated by the supplied weight matrix.
            _, embed_dim = pretrained_embs.shape
            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
        else:
            self.embedding = nn.Embedding(
                num_embeddings=vocab_size,
                embedding_dim=embed_dim,
                padding_idx=padding_idx,
            )

        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    in_channels=embed_dim,
                    out_channels=num_filters,
                    kernel_size=k,
                )
                for k in kernel_sizes
            ]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)

        # Shortest sequence length every convolution can process; forward()
        # pads inputs up to this length.
        self._max_kernel_size = max(kernel_sizes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute raw class scores for a batch of token-index sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` with unnormalized scores.
        """
        # (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len): Conv1d
        # convolves over the last axis and treats axis 1 as channels.
        embedded = self.embedding(x).permute(0, 2, 1)

        # Conv1d raises when the input is shorter than its kernel, so
        # right-pad the sequence axis with zeros up to the widest kernel.
        shortfall = self._max_kernel_size - embedded.size(2)
        if shortfall > 0:
            embedded = F.pad(embedded, (0, shortfall))

        # Max-over-time pooling: keep the values of max(), drop the indices.
        pooled = [F.relu(conv(embedded)).max(dim=2)[0] for conv in self.convs]
        features = self.dropout(torch.cat(pooled, dim=1))
        return self.fc(features)

0 commit comments

Comments
 (0)