
Commit 1c8575a

change CNNScorer similar to RNN one

1 parent 135ca45

2 files changed: +142 -136 lines changed

autointent/configs/_transformers.py

Lines changed: 1 addition & 1 deletion
@@ -125,9 +125,9 @@ class CrossEncoderConfig(HFModelConfig):
 
 class CNNConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
+    device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
     max_seq_length: int = Field(128, description="Maximum sequence length.")
     padding_idx: int = Field(0, description="Index used for padding.")
-    unknown_idx: int = Field(1, description="Index used for unknown.")
     batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
 
     @classmethod
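
Because CNNConfig declares model_config = ConfigDict(extra="forbid"), this change is breaking for any caller that still passes the removed unknown_idx field. A minimal sketch of the new surface, assuming standard pydantic v2 validation; the import path is a guess, since _transformers.py is a private module:

    # Import path is an assumption; CNNConfig is defined in a private module.
    from autointent.configs._transformers import CNNConfig

    # The new optional device field accepts torch-style device strings.
    config = CNNConfig(device="cuda:0", max_seq_length=64)
    print(config.batch_size)  # 32, the declared default

    # The removed field is now rejected outright because extra="forbid":
    # CNNConfig(unknown_idx=1)  # raises pydantic.ValidationError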
Lines changed: 141 additions & 135 deletions
@@ -1,14 +1,10 @@
-"""CNNScorer class for scoring."""
-
-import re
-from collections import Counter
 from typing import Any
 
 import numpy as np
 import numpy.typing as npt
 import torch
 from torch import nn
-from torch.utils.data import DataLoader, TensorDataset
+from torch.optim import Adam
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -19,190 +15,200 @@
 
 
 class CNNScorer(BaseScorer):
-    """Convolutional Neural Network (CNN) scorer for intent classification."""
+    """Scorer based on CNN model for text classification."""
 
     name = "cnn"
-    supports_multilabel = True
     supports_multiclass = True
+    supports_multilabel = True
 
     def __init__(
         self,
-        num_train_epochs: int = 3,
-        learning_rate: float = 5e-5,
-        seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
         embed_dim: int = 128,
-        kernel_sizes: list[int] = [3, 4, 5],  # noqa: B006
+        kernel_sizes: list[int] = [3, 4, 5],
         num_filters: int = 100,
         dropout: float = 0.1,
-        batch_size: int = 8,
         cnn_config: CNNConfig | str | dict[str, Any] | None = None,
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
     ) -> None:
-        self.num_train_epochs = num_train_epochs
-        self.learning_rate = learning_rate
-        self.seed = seed
-        self.report_to = report_to
+        """Initialize the CNN scorer."""
        self.embed_dim = embed_dim
         self.kernel_sizes = kernel_sizes
         self.num_filters = num_filters
         self.dropout = dropout
         self.cnn_config = CNNConfig.from_search_config(cnn_config)
-
-        # Will be initialized during fit()
-        self._model: TextCNN | None = None
-        self._vocab: dict[str, int] | None = None
-        self._unk_token = "<UNK>"  # noqa: S105
-        self._pad_token = "<PAD>"  # noqa: S105
-        self._n_classes: int = 0
-        self._multilabel: bool = False
-        self._pad_idx = self.cnn_config.padding_idx
-        self._unk_idx = self.cnn_config.unknown_idx
-        self.batch_size = batch_size
-        self.max_seq_length = self.cnn_config.max_seq_length
+        self.num_train_epochs = num_train_epochs
+        self.batch_size = batch_size or self.cnn_config.batch_size
+        self.learning_rate = learning_rate
+        self.seed = seed
+        self.report_to = report_to
+        self._artifact = None
+        self._device = self.cnn_config.device or ("cuda" if torch.cuda.is_available() else "cpu")
 
     @classmethod
     def from_context(
         cls,
         context: Context,
+        embed_dim: int = 128,
+        kernel_sizes: list[int] = [3, 4, 5],
+        num_filters: int = 100,
+        dropout: float = 0.1,
+        cnn_config: CNNConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        embed_dim: int = 128,
-        kernel_sizes: list[int] = [3, 4, 5],  # noqa: B006
-        num_filters: int = 100,
-        dropout: float = 0.1,
-        cnn_config: CNNConfig | str | dict[str, Any] | None = None
     ) -> "CNNScorer":
+        """Create a CNNScorer from context."""
+        report_to = context.logging_config.report_to
+
         return cls(
-            num_train_epochs=num_train_epochs,
-            batch_size=batch_size,
-            learning_rate=learning_rate,
-            seed=seed,
-            report_to=context.logging_config.report_to,
             embed_dim=embed_dim,
             kernel_sizes=kernel_sizes,
             num_filters=num_filters,
             dropout=dropout,
-            cnn_config=cnn_config
-        )
-
-    def get_implicit_initialization_params(self) -> dict[str, Any]:
-        return {"cnn_config": self.cnn_config.model_dump()}
-
-    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
-        self._validate_task(labels)
-        self._multilabel = isinstance(labels[0], (list, np.ndarray))  # noqa: UP038
-
-        # Build vocabulary and tokenize
-        self._build_vocab(utterances)
-
-        # Convert text to padded indices
-        x = self._text_to_indices(utterances)
-        x_tensor = torch.tensor(x, dtype=torch.long)
-        y_tensor = torch.tensor(
-            labels, dtype=torch.long if not self._multilabel else torch.float
+            cnn_config=cnn_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,
         )
 
-        # Initialize model
-        if self._vocab is None:
-            msg = "Vocabulary not built"
-            raise ValueError(msg)
-
+    def get_embedder_config(self) -> dict[str, Any]:
+        """Get the configuration of the embedder."""
+        config = self.cnn_config.model_dump()
+        config.update({
+            "embed_dim": self.embed_dim,
+            "kernel_sizes": self.kernel_sizes,
+            "num_filters": self.num_filters,
+            "dropout": self.dropout,
+        })
+        return config
+
+    def __initialize_model(self, vocab_size: int) -> None:
+        """Initialize the CNN model."""
         self._model = TextCNN(
-            vocab_size=len(self._vocab),
+            vocab_size=vocab_size,
             n_classes=self._n_classes,
             embed_dim=self.embed_dim,
             kernel_sizes=self.kernel_sizes,
             num_filters=self.num_filters,
             dropout=self.dropout,
-            padding_idx=self._pad_idx
+            padding_idx=self.cnn_config.padding_idx,
+            pretrained_embs=None,
         )
+        self._model.to(self.device)
+
+    def fit(
+        self,
+        utterances: list[str],
+        labels: ListOfLabels,
+    ) -> None:
+        """Fit the model to the given data."""
+        if hasattr(self, "_model"):
+            self.clear_cache()
+        self._validate_task(labels)
+        self._create_vocab(utterances)
+        self.__initialize_model(len(self._vocab))
+        x = self._texts_to_sequences(utterances)
+        y = torch.tensor(labels, dtype=torch.float) if self._multilabel else torch.tensor(labels, dtype=torch.long)
+        self._train_model(x, y)
+
+    def _create_vocab(self, utterances: list[str]) -> None:
+        """Create vocabulary from utterances."""
+        unique_words = set()
+        for text in utterances:
+            for word in text.lower().split():
+                unique_words.add(word)
+
+        self._vocab = {"<PAD>": 0, "<UNK>": 1}
+        for i, word in enumerate(unique_words):
+            self._vocab[word] = i + 2
+
+    def _texts_to_sequences(self, texts: list[str]) -> torch.Tensor:
+        """Convert texts to sequences using the vocabulary."""
+        sequences = [[self._vocab.get(word, self._vocab["<UNK>"]) for word in text.lower().split()] for text in texts]
+
+        max_len = min(max(len(seq) for seq in sequences), self.cnn_config.max_seq_length)
+        padded_sequences = [
+            seq[:max_len] if len(seq) > max_len else seq + [self._vocab["<PAD>"]] * (max_len - len(seq))
+            for seq in sequences
+        ]
+
+        return torch.tensor(padded_sequences, dtype=torch.long)
+
+    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        """Train the model."""
+        self._model.train()
+        optimizer = Adam(self._model.parameters(), lr=self.learning_rate)
+
+        criterion = nn.BCEWithLogitsLoss() if self._multilabel else nn.CrossEntropyLoss()
+
+        x = x.to(self._device)
+        y = y.to(self._device)
+
+        dataset = torch.utils.data.TensorDataset(x, y)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+
+        torch.manual_seed(self.seed)
+
+        for _epoch in range(self.num_train_epochs):
+            total_loss = 0
+            for batch_x, batch_y in dataloader:
+                optimizer.zero_grad()
+                outputs = self._model(batch_x)
+                loss = criterion(outputs, batch_y)
+                loss.backward()
+                optimizer.step()
+                total_loss += loss.item()
 
-        # Training
-        self._train_model(x_tensor, y_tensor)
+        self._model.eval()
 
     def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
-        if self._model is None:
-            msg = "Model not trained. Call fit() first."
-            raise ValueError(msg)
+        """Predict probabilities for utterances."""
+        if not hasattr(self, "_model") or not hasattr(self, "_vocab"):
+            msg = "Model is not trained. Call fit() first."
+            raise RuntimeError(msg)
 
-        x = self._text_to_indices(utterances)
-        x_tensor = torch.tensor(x, dtype=torch.long)
+        x = self._texts_to_sequences(utterances)
+        x = x.to(self.device)
 
         self._model.eval()
-        all_probs: list[npt.NDArray[Any]] = []
+        all_predictions = []
 
         with torch.no_grad():
-            for i in range(0, len(x_tensor), self.batch_size):
-                batch_x = x_tensor[i : i + self.batch_size]
+            for i in range(0, len(x), self.batch_size):
+                batch_x = x[i : i + self.batch_size]
                 outputs = self._model(batch_x)
+
                 if self._multilabel:
-                    probs = torch.sigmoid(outputs).cpu().numpy()
+                    batch_predictions = torch.sigmoid(outputs).cpu().numpy()
                 else:
-                    probs = torch.softmax(outputs, dim=1).cpu().numpy()
-                all_probs.append(probs)
-
-        return np.concatenate(all_probs, axis=0) if all_probs else np.array([])
-
-    def _build_vocab(self, utterances: list[str]) -> None:
-        """Build vocabulary from training utterances."""
-        word_counts: Counter[str] = Counter()
-        for utterance in utterances:
-            words = re.findall(r"\w+", utterance.lower())
-            word_counts.update(words)
-
-        # Create vocabulary with special tokens
-        self._vocab = {self._pad_token: self._pad_idx, self._unk_token: self._unk_idx}
-
-        # Convert Counter to list of (word, count) tuples sorted by frequency
-        sorted_words = word_counts.most_common()
-        for word, _ in sorted_words:
-            if word not in self._vocab:
-                self._vocab[word] = len(self._vocab)
-
-    def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
-        """Convert utterances to padded sequences of word indices."""
-        if self._vocab is None:
-            msg = "Vocabulary not built"
-            raise ValueError(msg)
-
-        sequences: list[list[int]] = []
-        for utterance in utterances:
-            words = re.findall(r"\w+", utterance.lower())
-            # Convert words to indices, using UNK for unknown words
-            seq = [self._vocab.get(word, self._unk_idx) for word in words]
-            # Truncate if too long
-            seq = seq[: self.max_seq_length]
-            # Pad if too short
-            seq = seq + [self._pad_idx] * (self.max_seq_length - len(seq))
-            sequences.append(seq)
-        return sequences
+                    batch_predictions = torch.softmax(outputs, dim=1).cpu().numpy()
 
-    def clear_cache(self) -> None:
-        self._model = None
-        torch.cuda.empty_cache()
+                all_predictions.append(batch_predictions)
 
-    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
-        if self._model is None:
-            msg = "Model not initialized"
-            raise ValueError(msg)
+        return np.vstack(all_predictions) if all_predictions else np.array([])
 
-        dataset = TensorDataset(x, y)
-        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+    def clear_cache(self) -> None:
+        """Clear model cache."""
+        if hasattr(self, "_model"):
+            del self._model
 
-        criterion = (
-            nn.CrossEntropyLoss() if not self._multilabel else nn.BCEWithLogitsLoss()
-        )
-        optimizer = torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)
+    @property
+    def device(self) -> str:
+        """Get device used for model computations."""
+        return self._device
 
-        self._model.train()
-        for _ in range(self.num_train_epochs):
-            for batch_x, batch_y in dataloader:
-                optimizer.zero_grad()
-                outputs = self._model(batch_x)
-                loss = criterion(outputs, batch_y)
-                loss.backward()
-                optimizer.step()
+    @device.setter
+    def device(self, value: str) -> None:
+        """Set device for model computations."""
+        self._device = value
 
-        self._model.eval()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        """Return default params used in ``__init__`` method."""
+        return {"cnn_config": self.cnn_config.model_dump()}
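
Note that _texts_to_sequences replaces the old _text_to_indices: tokenization is now a plain str.lower().split() instead of re.findall(r"\w+", ...), and each batch is padded to its own longest sequence (capped at max_seq_length) rather than always to max_seq_length. A standalone sketch of that padding logic, with a toy vocabulary inlined for illustration:

    import torch

    # Toy vocabulary; in the scorer this is built by _create_vocab().
    vocab = {"<PAD>": 0, "<UNK>": 1, "turn": 2, "on": 3, "lights": 4}
    max_seq_length = 128  # taken from CNNConfig

    texts = ["Turn on the lights", "on"]
    sequences = [[vocab.get(w, vocab["<UNK>"]) for w in t.lower().split()] for t in texts]

    # Pad to the longest sequence in this batch, capped at max_seq_length.
    # (A negative repeat count yields an empty list, so over-long sequences
    # are simply truncated, matching the conditional in the diff.)
    max_len = min(max(len(s) for s in sequences), max_seq_length)
    padded = [s[:max_len] + [vocab["<PAD>"]] * (max_len - len(s)) for s in sequences]

    print(torch.tensor(padded, dtype=torch.long))
    # tensor([[2, 3, 1, 4],
    #         [3, 0, 0, 0]])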

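For reference, a minimal end-to-end sketch of the reworked interface. The import path is an assumption (the commit does not show where the scorer module lives), and the sketch assumes the base class's _validate_task sets the internal task attributes (such as _n_classes and _multilabel) from the labels:

    # Hypothetical usage; the import path below is an assumption.
    from autointent.modules.scoring import CNNScorer

    scorer = CNNScorer(embed_dim=64, num_train_epochs=2, batch_size=4, seed=0)
    scorer.device = "cpu"  # the device setter introduced in this commit

    scorer.fit(
        ["turn on the lights", "what is the weather", "play some music"],
        [0, 1, 2],
    )

    probs = scorer.predict(["switch off the lamp"])  # shape (1, n_classes)
    scorer.clear_cache()  # drops the trained model, as fit() now does before refitting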