
Commit 72cc7dc

SeBorgey, voorhs, and github-actions[bot] authored
Rnn scorer (#190)
* first code for rnn scorer
* config fix
* last ruff fix
* tests
* typing
* device
* dump load test
* parameters
* upgrade dumper for rnn, upgrade tests for new config
* mypy fix
* fix tests except dumpload
* dumpload test fix
* codestyle
* refactor working with vocab
* refactor cnn and rnn
* fix typing and codestyle
* refactor init arguments of rnn and cnn scorers
* add strict mode to `Dumper`
* codestyle
* Update optimizer_config.schema.json
* bug fix
* add TODO comments
* try to implement early stopping
* add logging messages

---------

Co-authored-by: voorhs <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ac1b732 commit 72cc7dc

22 files changed: +972, −347 lines


autointent/_dump_tools.py

Lines changed: 46 additions & 9 deletions
@@ -11,7 +11,6 @@
 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from torch import nn
 from transformers import (  # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -21,15 +20,22 @@
 )
 
 from autointent import Embedder, Ranker, VectorIndex
-from autointent._wrappers import BaseTorchModule
+from autointent._wrappers import BaseTorchModuleWithVocab
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context.optimization_info import Artifact
 from autointent.schemas import TagsList
 
 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]
 
 ModuleAttributes: TypeAlias = (
-    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker | nn.Module  # type: ignore[type-arg]
+    ModuleSimpleAttributes
+    | TagsList
+    | np.ndarray  # type: ignore[type-arg]
+    | Embedder
+    | VectorIndex
+    | BaseEstimator
+    | Ranker
+    | BaseTorchModuleWithVocab
 )
 
 logger = logging.getLogger(__name__)
@@ -75,14 +81,21 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             subdir.mkdir(parents=True, exist_ok=exists_ok)
 
     @staticmethod
-    def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]] | None = None) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
+    def dump(  # noqa: C901, PLR0912, PLR0915
+        obj: Any,  # noqa: ANN401
+        path: Path,
+        exists_ok: bool = False,
+        exclude: list[type[Any]] | None = None,
+        raise_errors: bool = False,
+    ) -> None:
         """Dump modules attributes to filestystem.
 
         Args:
             obj: Object to dump
             path: Path to dump to
             exists_ok: If True, do not raise an error if the directory already exists
             exclude: List of types to exclude from dumping
+            raise_errors: whether to raise dumping errors or just log
         """
         attrs: dict[str, ModuleAttributes] = vars(obj)
         simple_attrs = {}
@@ -119,25 +132,29 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
                     logging.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PeftModel):
                 # dumping peft models is a nightmare...
                 # this might break with new versions of peft
                 try:
                     if val._is_prompt_learning:  # noqa: SLF001
                         # strategy to save prompt learning models: save prompt encoder and bert classifier separately
                         model_path = path / Dumper.ptuning_models / key
-                        model_path.mkdir(parents=True, exist_ok=True)
+                        model_path.mkdir(parents=True, exist_ok=exists_ok)
                         val.save_pretrained(str(model_path / "peft"))
                         val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
                     else:
                         # strategy to save lora models: merge adapters and save as usual hugging face model
                         model_path = path / Dumper.hf_models / key
-                        model_path.mkdir(parents=True, exist_ok=True)
+                        model_path.mkdir(parents=True, exist_ok=exists_ok)
                         merged_model: PreTrainedModel = val.merge_and_unload()
                         merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
                 except Exception as e:
                     msg = f"Error dumping PeftModel {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PreTrainedModel):
                 model_path = path / Dumper.hf_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
@@ -146,7 +163,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
-            elif isinstance(val, BaseTorchModule):
+                    if raise_errors:
+                        raise
+            elif isinstance(val, BaseTorchModuleWithVocab):
                 model_path = path / Dumper.torch_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
                 try:
@@ -160,6 +179,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping torch model {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
@@ -168,11 +189,15 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping HF tokenizer {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, CatBoostClassifier):
                 val.save_model(str(path / Dumper.catboost_models / key), format="cbm")
             else:
                 msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
                 logger.error(msg)
+                if raise_errors:
+                    raise TypeError(msg)
 
         with (path / Dumper.simple_attrs).open("w", encoding="utf-8") as file:
             json.dump(simple_attrs, file, ensure_ascii=False, indent=4)
@@ -185,6 +210,7 @@ def load(  # noqa: C901, PLR0912, PLR0915
         path: Path,
         embedder_config: EmbedderConfig | None = None,
         cross_encoder_config: CrossEncoderConfig | None = None,
+        raise_errors: bool = False,
     ) -> None:
         """Load attributes from file system."""
         tags: dict[str, Any] = {}
@@ -250,7 +276,8 @@ def load(  # noqa: C901, PLR0912, PLR0915
                 except Exception as e:
                     msg = f"Error loading Pydantic model from {model_dir}: {e}"
                     logger.exception(msg)
-                    continue
+                    if raise_errors:
+                        raise
            elif child.name == Dumper.ptuning_models:
                 for model_dir in child.iterdir():
                     try:
@@ -259,20 +286,26 @@ def load(  # noqa: C901, PLR0912, PLR0915
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading PeftModel {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
                         hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)  # type: ignore[no-untyped-call]
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF model {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.hf_tokenizers:
                 for tokenizer_dir in child.iterdir():
                     try:
                         hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.catboost_models:
                 for model_file in child.iterdir():
                     try:
@@ -288,15 +321,19 @@ def load(  # noqa: C901, PLR0912, PLR0915
                         with (model_dir / "class_info.json").open("r") as f:
                             class_info = json.load(f)
                         module = importlib.import_module(class_info["module"])
-                        model_class: BaseTorchModule = getattr(module, class_info["name"])
+                        model_class: BaseTorchModuleWithVocab = getattr(module, class_info["name"])
                         model = model_class.load(model_dir)
                         torch_models[model_dir.name] = model
                     except Exception as e:
                         msg = f"Error loading torch model {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             else:
                 msg = f"Found unexpected child {child}"
                 logger.error(msg)
+                if raise_errors:
+                    raise ValueError(msg)
 
         obj.__dict__.update(
             tags
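
The new `raise_errors` flag gives `Dumper` a strict mode: by default, a failure on one attribute is only logged and the remaining attributes are still processed, while in strict mode the first failure propagates. A minimal sketch of the difference (the `scorer` object is hypothetical; any module instance whose attributes `Dumper` can serialize would do):

from pathlib import Path

from autointent._dump_tools import Dumper

scorer = ...  # hypothetical: some module object whose attributes should be persisted

# Lenient (default): errors for individual attributes are logged and skipped.
Dumper.dump(scorer, Path("dump_dir"), exists_ok=True)

# Strict mode added in this commit: the first dumping error is re-raised.
Dumper.dump(scorer, Path("dump_dir"), exists_ok=True, raise_errors=True)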

autointent/_wrappers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from .ranker import Ranker
 from .embedder import Embedder
 from .vector_index import VectorIndex
-from .base_torch_module import BaseTorchModule
+from .base_torch_module import BaseTorchModuleWithVocab
 
-__all__ = ["BaseTorchModule", "Embedder", "Ranker", "VectorIndex"]
+__all__ = ["BaseTorchModuleWithVocab", "Embedder", "Ranker", "VectorIndex"]

autointent/_wrappers/base_torch_module.py

Lines changed: 86 additions & 1 deletion
@@ -1,12 +1,97 @@
+"""Torch model for text classification."""
+
+import re
 from abc import ABC, abstractmethod
+from collections import Counter
 from pathlib import Path
+from typing import Any
 
 import torch
 from torch import nn
 from typing_extensions import Self
 
+from autointent.configs import VocabConfig
+
+
+class BaseTorchModuleWithVocab(nn.Module, ABC):
+    def __init__(
+        self,
+        embed_dim: int,
+        vocab_config: VocabConfig | None = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_dim = embed_dim
+        self.vocab_config = VocabConfig.from_search_config(vocab_config)
+
+        # Vocabulary management
+        self._unk_token = "<UNK>"  # noqa: S105
+        self._pad_token = "<PAD>"  # noqa: S105
+        self._unk_idx = 1
+
+        if self.vocab_config.vocab is not None:
+            self.set_vocab(self.vocab_config.vocab)
+
+    def set_vocab(self, vocab: dict[str, Any]) -> None:
+        """Save vocabulary into module's attributes and initialize embeddings matrix."""
+        self.vocab_config.vocab = vocab
+        self.embedding = nn.Embedding(
+            num_embeddings=len(self.vocab_config.vocab),
+            embedding_dim=self.embed_dim,
+            padding_idx=self.vocab_config.padding_idx,
+        )
+
+    def build_vocab(self, utterances: list[str]) -> None:
+        """Build vocabulary from training utterances."""
+        if self.vocab_config.vocab is not None:
+            msg = "Vocab is already built."
+            raise RuntimeError(msg)
+
+        word_counts: Counter[str] = Counter()
+        for utterance in utterances:
+            words = re.findall(r"\w+", utterance.lower())
+            word_counts.update(words)
+
+        # Create vocabulary with special tokens
+        vocab = {self._pad_token: self.vocab_config.padding_idx, self._unk_token: self._unk_idx}
+
+        # Convert Counter to list of (word, count) tuples sorted by frequency
+        sorted_words = word_counts.most_common(self.vocab_config.max_vocab_size)
+        for word, _ in sorted_words:
+            if word not in vocab:
+                vocab[word] = len(vocab)
+
+        self.set_vocab(vocab)
+
+    def text_to_indices(self, utterances: list[str]) -> list[list[int]]:
+        """Convert utterances to padded sequences of word indices."""
+        if self.vocab_config.vocab is None:
+            msg = "Vocab is not built."
+            raise RuntimeError(msg)
+
+        sequences: list[list[int]] = []
+        for utterance in utterances:
+            words = re.findall(r"\w+", utterance.lower())
+            # Convert words to indices, using UNK for unknown words
+            seq = [self.vocab_config.vocab.get(word, self._unk_idx) for word in words]
+            # Truncate if too long
+            seq = seq[: self.vocab_config.max_seq_length]
+            # Pad if too short
+            seq = seq + [self.vocab_config.padding_idx] * (self.vocab_config.max_seq_length - len(seq))
+            sequences.append(seq)
+        return sequences
+
+    @abstractmethod
+    def forward(self, text: torch.Tensor) -> torch.Tensor:
+        """Compute sentence embeddings for given text.
+
+        Args:
+            text: torch tensor of shape (B, T), token ids
+
+        Returns:
+            embeddings of shape (B, H)
+        """
 
-class BaseTorchModule(nn.Module, ABC):
     @abstractmethod
     def dump(self, path: Path) -> None:
         """Dump torch module to disk.

autointent/configs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, HPOConfig, LoggingConfig
+from ._torch import TorchTrainingConfig, VocabConfig
 from ._transformers import (
     CrossEncoderConfig,
     EarlyStoppingConfig,
@@ -19,8 +20,9 @@
     "HFModelConfig",
     "HPOConfig",
     "InferenceNodeConfig",
-    "InferenceNodeConfig",
     "LoggingConfig",
     "TaskTypeEnum",
     "TokenizerConfig",
+    "TorchTrainingConfig",
+    "VocabConfig",
 ]

autointent/configs/_torch.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Self
+
+from autointent._callbacks import REPORTERS_NAMES
+from autointent._utils import detect_device
+
+
+class FromDictMixin:
+    @classmethod
+    def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self:
+        """Validate the model configuration.
+
+        This classmethod is used to parse dictionaries that occur in search space configurations.
+
+        Args:
+            values: Model configuration values.
+
+        Returns:
+            Model configuration.
+        """
+        if values is None:
+            return cls()
+        if isinstance(values, BaseModel):
+            return values  # type: ignore[return-value]
+        return cls(**values)
+
+
+class VocabConfig(BaseModel, FromDictMixin):
+    model_config = ConfigDict(extra="forbid")
+    padding_idx: int = 0
+    max_seq_length: int = 50
+    vocab: dict[str, int] | None = None
+    max_vocab_size: int | None = None
+
+
+class TorchTrainingConfig(BaseModel, FromDictMixin):
+    model_config = ConfigDict(extra="forbid")
+    num_train_epochs: int = 3
+    batch_size: int = 8
+    learning_rate: float = 5e-5
+    seed: int = 42
+    report_to: REPORTERS_NAMES | None = None  # type: ignore  # noqa: PGH003
+    device: str = Field(default_factory=detect_device)
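
The two new configs are plain pydantic models, so they can be constructed directly or, via the `FromDictMixin` shown above, from the dictionaries that appear in search-space configurations. A short illustrative sketch (the values are arbitrary):

from autointent.configs import TorchTrainingConfig, VocabConfig

# Direct construction; extra="forbid" rejects unknown keys.
vocab_config = VocabConfig(max_seq_length=30, max_vocab_size=5000)

# from_search_config() accepts None (defaults), an already-built model, or a plain dict,
# which is how values coming from a search space are parsed.
training_config = TorchTrainingConfig.from_search_config({"num_train_epochs": 5, "batch_size": 16})
default_training_config = TorchTrainingConfig.from_search_config(None)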
