
Commit 644a849

SeBorgey, voorhs, and github-actions[bot] authored
full tuning (#165)
* Added code for full tuning
* work on review
* renaming
* fix ruff
* mypy test
* ignore mypy
* Feat/bert scorer config refactoring (#168)
* refactor configs
* add proper configs to BERTScorer
* fix typing
* fix tokenizer's parameters
* fix transformers and accelerate issue
* Update optimizer_config.schema.json
* bug fix
* update callback test
* fix tests
* delete validate_task
* report_to
* batches
* Fix/docs building for bert scorer (#171)
* fix
* fix codestyle

Co-authored-by: Алексеев Илья <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 1ee5366 commit 644a849

File tree

16 files changed: +354 additions, -77 deletions


autointent/_embedder.py

Lines changed: 9 additions & 5 deletions
@@ -87,7 +87,7 @@ def __hash__(self) -> int:
         hasher = Hasher()
         for parameter in self.embedding_model.parameters():
             hasher.update(parameter.detach().cpu().numpy())
-        hasher.update(self.config.max_length)
+        hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

     def clear_ram(self) -> None:
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
             model_name=str(self.config.model_name),
             device=self.config.device,
             batch_size=self.config.batch_size,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             use_cache=self.config.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
@@ -137,6 +137,10 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
         else:
             kwargs = metadata  # type: ignore[assignment]

+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(EmbedderConfig(**kwargs))

     def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
@@ -162,12 +166,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
             self.config.model_name,
             self.config.batch_size,
-            str(self.config.max_length),
+            str(self.config.tokenizer_config.max_length),
             self.config.device,
         )

-        if self.config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.max_length
+        if self.config.tokenizer_config.max_length is not None:
+            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length

         embeddings = self.embedding_model.encode(
             utterances,
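The load-time fallback above is the backward-compatibility path for dumps written before this commit. A minimal sketch of its behaviour, using an illustrative metadata dict rather than a real dump on disk:

# Illustrative only: older Embedder dumps stored a flat "max_length" key.
old_metadata = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "device": None,
    "batch_size": 32,
    "max_length": 128,
    "use_cache": False,
}

kwargs = dict(old_metadata)
max_length = kwargs.pop("max_length", None)
if max_length is not None:
    # The flat value is folded into the new nested tokenizer_config,
    # so EmbedderConfig(**kwargs) accepts both old and new metadata.
    kwargs["tokenizer_config"] = {"max_length": max_length}

The same translation is applied for CrossEncoderConfig in autointent/_ranker.py below.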

autointent/_ranker.py

Lines changed: 6 additions & 2 deletions
@@ -113,7 +113,7 @@ def __init__(
             self.config.model_name,
             trust_remote_code=True,
             device=self.config.device,
-            max_length=self.config.max_length,  # type: ignore[arg-type]
+            max_length=self.config.tokenizer_config.max_length,  # type: ignore[arg-type]
         )
         self._train_head = False
         self._clf = classifier_head
@@ -252,7 +252,7 @@ def save(self, path: str) -> None:
             model_name=self.config.model_name,
             train_head=self._train_head,
             device=self.config.device,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             batch_size=self.config.batch_size,
         )

@@ -282,6 +282,10 @@ def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) ->
         else:
             kwargs = metadata  # type: ignore[assignment]

+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(
             CrossEncoderConfig(**kwargs),
             classifier_head=clf,

autointent/_vector_index.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
     import numpy.typing as npt

 from autointent import Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, TokenizerConfig
 from autointent.custom_types import ListOfLabels

@@ -195,7 +195,7 @@ def dump(self, dir_path: Path) -> None:
             json.dump(data, file, indent=4, ensure_ascii=False)

         metadata = VectorIndexMetadata(
-            embedder_max_length=self.embedder.config.max_length,
+            embedder_max_length=self.embedder.config.tokenizer_config.max_length,
             embedder_model_name=str(self.embedder.config.model_name),
             embedder_device=self.embedder.config.device,
             embedder_batch_size=self.embedder.config.batch_size,
@@ -229,7 +229,7 @@ def load(
                 model_name=metadata["embedder_model_name"],
                 device=embedder_device or metadata["embedder_device"],
                 batch_size=embedder_batch_size or metadata["embedder_batch_size"],
-                max_length=metadata["embedder_max_length"],
+                tokenizer_config=TokenizerConfig(max_length=metadata["embedder_max_length"]),
                 use_cache=embedder_use_cache or metadata["embedder_use_cache"],
             )
         )

autointent/configs/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -2,14 +2,16 @@

 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, LoggingConfig
-from ._transformers import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
+from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig

 __all__ = [
     "CrossEncoderConfig",
     "DataConfig",
     "EmbedderConfig",
+    "HFModelConfig",
     "InferenceNodeConfig",
     "InferenceNodeConfig",
     "LoggingConfig",
     "TaskTypeEnum",
+    "TokenizerConfig",
 ]

autointent/configs/_transformers.py

Lines changed: 14 additions & 9 deletions
@@ -1,19 +1,24 @@
 from enum import Enum
-from typing import Any
+from typing import Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
 from typing_extensions import Self, assert_never


-class ModelConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
+class TokenizerConfig(BaseModel):
+    padding: bool | Literal["longest", "max_length", "do_not_pad"] = True
+    truncation: bool = True
     max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")


-class STModelConfig(ModelConfig):
-    model_name: str
+class HFModelConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    model_name: str = Field(
+        "prajjwal1/bert-tiny", description="Name of the hugging face repository with transformer model."
+    )
+    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
     device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
+    tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)

     @classmethod
     def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -26,7 +31,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
             Model configuration.
         """
         if values is None:
-            return cls()  # type: ignore[call-arg]
+            return cls()
         if isinstance(values, BaseModel):
             return values  # type: ignore[return-value]
         if isinstance(values, str):
@@ -45,7 +50,7 @@ class TaskTypeEnum(Enum):
     sts = "sts"


-class EmbedderConfig(STModelConfig):
+class EmbedderConfig(HFModelConfig):
     model_name: str = Field("sentence-transformers/all-MiniLM-L6-v2", description="Name of the hugging face model.")
     default_prompt: str | None = Field(
         None, description="Default prompt for the model. This is used when no task specific prompt is not provided."
@@ -105,7 +110,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # no
     use_cache: bool = Field(False, description="Whether to use embeddings caching.")


-class CrossEncoderConfig(STModelConfig):
+class CrossEncoderConfig(HFModelConfig):
     model_name: str = Field("cross-encoder/ms-marco-MiniLM-L-6-v2", description="Name of the hugging face model.")
     train_head: bool = Field(
         False, description="Whether to train the head of the model. If False, LogReg will be trained."
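A short sketch of how the refactored configs compose, based only on the fields visible in this diff; the concrete values are illustrative:

from autointent.configs import EmbedderConfig, HFModelConfig, TokenizerConfig

# Tokenizer settings now live in a nested config instead of a flat max_length field.
tokenizer = TokenizerConfig(padding="max_length", truncation=True, max_length=128)

# EmbedderConfig and CrossEncoderConfig both inherit from HFModelConfig,
# so they share model_name, batch_size, device and tokenizer_config.
embedder = EmbedderConfig(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    batch_size=32,
    tokenizer_config=tokenizer,
)

# from_search_config still accepts None, a string, a dict or a config object;
# with None it now returns a fully defaulted HFModelConfig (hence the dropped
# type: ignore[call-arg] above).
default_model = HFModelConfig.from_search_config(None)

print(embedder.tokenizer_config.max_length)  # 128
print(default_model.model_name)              # "prajjwal1/bert-tiny"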

autointent/modules/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -12,7 +12,16 @@
 )
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
-from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
+from .scoring import (
+    BertScorer,
+    DescriptionScorer,
+    DNNCScorer,
+    KNNScorer,
+    LinearScorer,
+    MLKnnScorer,
+    RerankScorer,
+    SklearnScorer,
+)

 T = TypeVar("T", bound=BaseModule)

@@ -36,6 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         RerankScorer,
         SklearnScorer,
         MLKnnScorer,
+        BertScorer,
     ]
 )

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from ._bert import BertScorer
 from ._description import DescriptionScorer
 from ._dnnc import DNNCScorer
 from ._knn import KNNScorer, RerankScorer
@@ -6,6 +7,7 @@
 from ._sklearn import SklearnScorer

 __all__ = [
+    "BertScorer",
     "DNNCScorer",
     "DescriptionScorer",
     "KNNScorer",

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
"""BertScorer class for transformer-based classification."""

import tempfile
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from autointent import Context
from autointent._callbacks import REPORTERS_NAMES
from autointent.configs import HFModelConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer


class BertScorer(BaseScorer):
    name = "transformer"
    supports_multiclass = True
    supports_multilabel = True
    _model: Any
    _tokenizer: Any

    def __init__(
        self,
        model_config: HFModelConfig | str | dict[str, Any] | None = None,
        num_train_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
        report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
    ) -> None:
        self.model_config = HFModelConfig.from_search_config(model_config)
        self.num_train_epochs = num_train_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed = seed
        self.report_to = report_to

    @classmethod
    def from_context(
        cls,
        context: Context,
        model_config: HFModelConfig | str | dict[str, Any] | None = None,
        num_train_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
    ) -> "BertScorer":
        if model_config is None:
            model_config = context.resolve_embedder()

        report_to = context.logging_config.report_to

        return cls(
            model_config=model_config,
            num_train_epochs=num_train_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            seed=seed,
            report_to=report_to,
        )

    def get_embedder_config(self) -> dict[str, Any]:
        return self.model_config.model_dump()

    def fit(
        self,
        utterances: list[str],
        labels: ListOfLabels,
    ) -> None:
        if hasattr(self, "_model"):
            self.clear_cache()

        self._validate_task(labels)

        model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)

        use_cpu = self.model_config.device == "cpu"

        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            return self._tokenizer(  # type: ignore[no-any-return]
                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
            )

        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                num_train_epochs=self.num_train_epochs,
                per_device_train_batch_size=self.batch_size,
                learning_rate=self.learning_rate,
                seed=self.seed,
                save_strategy="no",
                logging_strategy="steps",
                logging_steps=10,
                report_to=self.report_to,
                use_cpu=use_cpu,
            )

            trainer = Trainer(
                model=self._model,
                args=training_args,
                train_dataset=tokenized_dataset,
                tokenizer=self._tokenizer,
                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
            )

            trainer.train()

        self._model.eval()

    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
        if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
            msg = "Model is not trained. Call fit() first."
            raise RuntimeError(msg)

        all_predictions = []
        for i in range(0, len(utterances), self.batch_size):
            batch = utterances[i : i + self.batch_size]
            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
            with torch.no_grad():
                outputs = self._model(**inputs)
            logits = outputs.logits
            if self._multilabel:
                batch_predictions = torch.sigmoid(logits).numpy()
            else:
                batch_predictions = torch.softmax(logits, dim=1).numpy()
            all_predictions.append(batch_predictions)
        return np.vstack(all_predictions) if all_predictions else np.array([])

    def clear_cache(self) -> None:
        if hasattr(self, "_model"):
            del self._model
        if hasattr(self, "_tokenizer"):
            del self._tokenizer
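A hedged usage sketch for the new module. The pipeline normally builds it via from_context, and the base-class helpers (_validate_task, _n_classes, _multilabel) are not shown in this diff, so this assumes the module also works standalone through its public fit/predict API; the data and hyperparameters are illustrative:

from autointent.modules import BertScorer

# Full fine-tuning of a small HF checkpoint on a toy intent set.
scorer = BertScorer(
    model_config={"model_name": "prajjwal1/bert-tiny", "device": "cpu"},
    num_train_epochs=1,
    batch_size=8,
    learning_rate=5e-5,
)

utterances = ["turn on the lights", "what is the weather", "play some jazz"]
labels = [0, 1, 2]

scorer.fit(utterances, labels)        # fine-tunes the whole transformer with the HF Trainer
probs = scorer.predict(utterances)    # array of shape (n_utterances, n_classes) with softmax scores
scorer.clear_cache()                  # frees the model and tokenizer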

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             ),
         )

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 1 addition & 1 deletion
@@ -128,7 +128,7 @@ def fit(
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             )
         )
