Commit 5e08b44

Feat/bert scorer config refactoring (#168)
* refactor configs
* add proper configs to BERTScorer
* fix typing
* fix tokenizer's parameters
* fix transformers and accelerate issue
* Update optimizer_config.schema.json
* bug fix
* update callback test
* fix tests

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent cb2e610 commit 5e08b44

File tree

12 files changed: +128 / -103 lines changed

autointent/_embedder.py

Lines changed: 9 additions & 5 deletions

@@ -87,7 +87,7 @@ def __hash__(self) -> int:
         hasher = Hasher()
         for parameter in self.embedding_model.parameters():
             hasher.update(parameter.detach().cpu().numpy())
-        hasher.update(self.config.max_length)
+        hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()

     def clear_ram(self) -> None:
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
             model_name=str(self.config.model_name),
             device=self.config.device,
             batch_size=self.config.batch_size,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             use_cache=self.config.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
@@ -137,6 +137,10 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
         else:
             kwargs = metadata  # type: ignore[assignment]

+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(EmbedderConfig(**kwargs))

     def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
@@ -162,12 +166,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
             self.config.model_name,
             self.config.batch_size,
-            str(self.config.max_length),
+            str(self.config.tokenizer_config.max_length),
             self.config.device,
         )

-        if self.config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.max_length
+        if self.config.tokenizer_config.max_length is not None:
+            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length

         embeddings = self.embedding_model.encode(
             utterances,
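The load() change above keeps previously dumped embedders loadable: a legacy flat max_length entry in the metadata is folded into the new nested tokenizer_config before EmbedderConfig validates it. A minimal standalone sketch of that migration pattern (plain dicts with illustrative values, not the library's exact dump format):

# Sketch of the backward-compatibility step in load(); metadata values are hypothetical.
legacy_metadata = {"model_name": "sentence-transformers/all-MiniLM-L6-v2", "batch_size": 32, "max_length": 256}

kwargs = dict(legacy_metadata)
max_length = kwargs.pop("max_length", None)  # old dumps kept max_length at the top level
if max_length is not None:
    kwargs["tokenizer_config"] = {"max_length": max_length}  # the new schema nests it

# kwargs now matches the refactored EmbedderConfig layout
print(kwargs["tokenizer_config"])  # {'max_length': 256}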

autointent/_ranker.py

Lines changed: 6 additions & 2 deletions

@@ -113,7 +113,7 @@ def __init__(
             self.config.model_name,
             trust_remote_code=True,
             device=self.config.device,
-            max_length=self.config.max_length,  # type: ignore[arg-type]
+            max_length=self.config.tokenizer_config.max_length,  # type: ignore[arg-type]
         )
         self._train_head = False
         self._clf = classifier_head
@@ -252,7 +252,7 @@ def save(self, path: str) -> None:
             model_name=self.config.model_name,
             train_head=self._train_head,
             device=self.config.device,
-            max_length=self.config.max_length,
+            max_length=self.config.tokenizer_config.max_length,
             batch_size=self.config.batch_size,
         )

@@ -282,6 +282,10 @@ def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) ->
         else:
             kwargs = metadata  # type: ignore[assignment]

+        max_length = kwargs.pop("max_length", None)
+        if max_length is not None:
+            kwargs["tokenizer_config"] = {"max_length": max_length}
+
         return cls(
             CrossEncoderConfig(**kwargs),
             classifier_head=clf,

autointent/_vector_index.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
 import numpy.typing as npt

 from autointent import Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, TaskTypeEnum, TokenizerConfig
 from autointent.custom_types import ListOfLabels


@@ -195,7 +195,7 @@ def dump(self, dir_path: Path) -> None:
             json.dump(data, file, indent=4, ensure_ascii=False)

         metadata = VectorIndexMetadata(
-            embedder_max_length=self.embedder.config.max_length,
+            embedder_max_length=self.embedder.config.tokenizer_config.max_length,
             embedder_model_name=str(self.embedder.config.model_name),
             embedder_device=self.embedder.config.device,
             embedder_batch_size=self.embedder.config.batch_size,
@@ -229,7 +229,7 @@ def load(
                 model_name=metadata["embedder_model_name"],
                 device=embedder_device or metadata["embedder_device"],
                 batch_size=embedder_batch_size or metadata["embedder_batch_size"],
-                max_length=metadata["embedder_max_length"],
+                tokenizer_config=TokenizerConfig(max_length=metadata["embedder_max_length"]),
                 use_cache=embedder_use_cache or metadata["embedder_use_cache"],
             )
         )

autointent/configs/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -2,14 +2,16 @@

 from ._inference_node import InferenceNodeConfig
 from ._optimization import DataConfig, LoggingConfig
-from ._transformers import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
+from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig

 __all__ = [
     "CrossEncoderConfig",
     "DataConfig",
     "EmbedderConfig",
+    "HFModelConfig",
     "InferenceNodeConfig",
     "InferenceNodeConfig",
     "LoggingConfig",
     "TaskTypeEnum",
+    "TokenizerConfig",
 ]
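Both new classes are re-exported from the public configs package. A quick sanity check of the defaults visible in the autointent/configs/_transformers.py diff below (illustrative, assuming the package is installed at this commit):

from autointent.configs import HFModelConfig, TokenizerConfig

# Defaults per the field definitions in this commit.
print(HFModelConfig().model_name)    # "prajjwal1/bert-tiny"
print(TokenizerConfig().max_length)  # None (no length cap unless set)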

autointent/configs/_transformers.py

Lines changed: 14 additions & 9 deletions

@@ -1,19 +1,24 @@
 from enum import Enum
-from typing import Any
+from typing import Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
 from typing_extensions import Self, assert_never


-class ModelConfig(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
+class TokenizerConfig(BaseModel):
+    padding: bool | Literal["longest", "max_length", "do_not_pad"] = True
+    truncation: bool = True
     max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")


-class STModelConfig(ModelConfig):
-    model_name: str
+class HFModelConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    model_name: str = Field(
+        "prajjwal1/bert-tiny", description="Name of the hugging face repository with transformer model."
+    )
+    batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
     device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
+    tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)

     @classmethod
     def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -26,7 +31,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
             Model configuration.
         """
         if values is None:
-            return cls()  # type: ignore[call-arg]
+            return cls()
         if isinstance(values, BaseModel):
             return values  # type: ignore[return-value]
         if isinstance(values, str):
@@ -45,7 +50,7 @@ class TaskTypeEnum(Enum):
     sts = "sts"


-class EmbedderConfig(STModelConfig):
+class EmbedderConfig(HFModelConfig):
     model_name: str = Field("sentence-transformers/all-MiniLM-L6-v2", description="Name of the hugging face model.")
     default_prompt: str | None = Field(
         None, description="Default prompt for the model. This is used when no task specific prompt is not provided."
@@ -105,7 +110,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None:  # no
     use_cache: bool = Field(False, description="Whether to use embeddings caching.")


-class CrossEncoderConfig(STModelConfig):
+class CrossEncoderConfig(HFModelConfig):
     model_name: str = Field("cross-encoder/ms-marco-MiniLM-L-6-v2", description="Name of the hugging face model.")
     train_head: bool = Field(
         False, description="Whether to train the head of the model. If False, LogReg will be trained."
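With this refactor, tokenizer parameters (padding, truncation, max_length) live in a nested TokenizerConfig carried by every HFModelConfig subclass, replacing the former flat max_length field. A rough usage sketch built only from the fields visible in this diff (values are illustrative):

from autointent.configs import EmbedderConfig, TokenizerConfig

# The former flat max_length now sits under tokenizer_config.
config = EmbedderConfig(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    batch_size=16,
    tokenizer_config=TokenizerConfig(max_length=128, truncation=True),
)
print(config.tokenizer_config.max_length)  # 128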

autointent/modules/scoring/_bert.py

Lines changed: 6 additions & 29 deletions

@@ -16,25 +16,11 @@
 )

 from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer


-class TokenizerConfig:
-    """Configuration for tokenizer parameters."""
-
-    def __init__(
-        self,
-        max_length: int = 128,
-        padding: str = "max_length",
-        truncation: bool = True,
-    ) -> None:
-        self.max_length = max_length
-        self.padding = padding
-        self.truncation = truncation
-
-
 class BertScorer(BaseScorer):
     name = "transformer"
     supports_multiclass = True
@@ -45,31 +31,28 @@ class BertScorer(BaseScorer):

     def __init__(
         self,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
     ) -> None:
-        self.model_config = EmbedderConfig.from_search_config(model_config)
+        self.model_config = HFModelConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
-        self.tokenizer_config = tokenizer_config or TokenizerConfig()
         self._multilabel = False

     @classmethod
     def from_context(
         cls,
         context: Context,
-        model_config: EmbedderConfig | str | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
     ) -> "BertScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
@@ -79,7 +62,6 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            tokenizer_config=tokenizer_config,
         )

     def get_embedder_config(self) -> dict[str, Any]:
@@ -114,10 +96,7 @@ def fit(

         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"],
-                padding=self.tokenizer_config.padding,
-                truncation=self.tokenizer_config.truncation,
-                max_length=self.tokenizer_config.max_length,
+                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
             )

         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -154,9 +133,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)

-        inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
-        )
+        inputs = self._tokenizer(utterances, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())

         with torch.no_grad():
             outputs = self._model(**inputs)
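BertScorer now drops its ad-hoc TokenizerConfig class and instead unpacks HFModelConfig.tokenizer_config straight into the Hugging Face tokenizer call, since the dumped fields (padding, truncation, max_length) match the tokenizer's keyword arguments. A small sketch of that unpacking outside the scorer (prajjwal1/bert-tiny is just the HFModelConfig default; any HF checkpoint works):

from transformers import AutoTokenizer

from autointent.configs import TokenizerConfig

# model_dump() yields {"padding": ..., "truncation": ..., "max_length": ...},
# exactly the kwargs a HF tokenizer __call__ accepts.
tok_cfg = TokenizerConfig(padding=True, truncation=True, max_length=64)
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

inputs = tokenizer(["book a table for two", "what is my balance"], return_tensors="pt", **tok_cfg.model_dump())
print(inputs["input_ids"].shape)  # (2, padded_sequence_length)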

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 1 addition & 1 deletion

@@ -140,7 +140,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             ),
         )

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 1 addition & 1 deletion

@@ -128,7 +128,7 @@ def fit(
                 model_name=self.embedder_config.model_name,
                 device=self.embedder_config.device,
                 batch_size=self.embedder_config.batch_size,
-                max_length=self.embedder_config.max_length,
+                tokenizer_config=self.embedder_config.tokenizer_config,
                 use_cache=self.embedder_config.use_cache,
             )
         )
