autointent/_embedder.py (9 additions, 5 deletions)
@@ -87,7 +87,7 @@ def __hash__(self) -> int:
hasher = Hasher()
for parameter in self.embedding_model.parameters():
hasher.update(parameter.detach().cpu().numpy())
- hasher.update(self.config.max_length)
+ hasher.update(self.config.tokenizer_config.max_length)
return hasher.intdigest()

def clear_ram(self) -> None:
@@ -114,7 +114,7 @@ def dump(self, path: Path) -> None:
model_name=str(self.config.model_name),
device=self.config.device,
batch_size=self.config.batch_size,
- max_length=self.config.max_length,
+ max_length=self.config.tokenizer_config.max_length,
use_cache=self.config.use_cache,
)
path.mkdir(parents=True, exist_ok=True)
@@ -137,6 +137,10 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
else:
kwargs = metadata # type: ignore[assignment]

+ max_length = kwargs.pop("max_length", None)
+ if max_length is not None:
+ kwargs["tokenizer_config"] = {"max_length": max_length}

return cls(EmbedderConfig(**kwargs))

def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
@@ -162,12 +166,12 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
"Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
self.config.model_name,
self.config.batch_size,
- str(self.config.max_length),
+ str(self.config.tokenizer_config.max_length),
self.config.device,
)

- if self.config.max_length is not None:
- self.embedding_model.max_seq_length = self.config.max_length
+ if self.config.tokenizer_config.max_length is not None:
+ self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length

embeddings = self.embedding_model.encode(
utterances,
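A minimal sketch of the backward-compatibility branch added to Embedder.load() above, assuming an older dump whose metadata still carries a flat max_length key (the metadata values here are illustrative):

from autointent.configs import EmbedderConfig

# Hypothetical metadata read from an older dump; only "max_length" matters here.
metadata = {"model_name": "sentence-transformers/all-MiniLM-L6-v2", "batch_size": 32, "max_length": 256, "use_cache": False}

# Same translation as in load(): the flat key moves into the nested tokenizer config.
max_length = metadata.pop("max_length", None)
if max_length is not None:
    metadata["tokenizer_config"] = {"max_length": max_length}

config = EmbedderConfig(**metadata)
assert config.tokenizer_config.max_length == 256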
autointent/_ranker.py (6 additions, 2 deletions)
@@ -113,7 +113,7 @@ def __init__(
self.config.model_name,
trust_remote_code=True,
device=self.config.device,
- max_length=self.config.max_length, # type: ignore[arg-type]
+ max_length=self.config.tokenizer_config.max_length, # type: ignore[arg-type]
)
self._train_head = False
self._clf = classifier_head
@@ -252,7 +252,7 @@ def save(self, path: str) -> None:
model_name=self.config.model_name,
train_head=self._train_head,
device=self.config.device,
- max_length=self.config.max_length,
+ max_length=self.config.tokenizer_config.max_length,
batch_size=self.config.batch_size,
)

@@ -282,6 +282,10 @@ def load(cls, path: Path, override_config: CrossEncoderConfig | None = None) ->
else:
kwargs = metadata # type: ignore[assignment]

+ max_length = kwargs.pop("max_length", None)
+ if max_length is not None:
+ kwargs["tokenizer_config"] = {"max_length": max_length}

return cls(
CrossEncoderConfig(**kwargs),
classifier_head=clf,
autointent/_vector_index.py (3 additions, 3 deletions)
@@ -15,7 +15,7 @@
import numpy.typing as npt

from autointent import Embedder
- from autointent.configs import EmbedderConfig, TaskTypeEnum
+ from autointent.configs import EmbedderConfig, TaskTypeEnum, TokenizerConfig
from autointent.custom_types import ListOfLabels


@@ -195,7 +195,7 @@ def dump(self, dir_path: Path) -> None:
json.dump(data, file, indent=4, ensure_ascii=False)

metadata = VectorIndexMetadata(
- embedder_max_length=self.embedder.config.max_length,
+ embedder_max_length=self.embedder.config.tokenizer_config.max_length,
embedder_model_name=str(self.embedder.config.model_name),
embedder_device=self.embedder.config.device,
embedder_batch_size=self.embedder.config.batch_size,
@@ -229,7 +229,7 @@ def load(
model_name=metadata["embedder_model_name"],
device=embedder_device or metadata["embedder_device"],
batch_size=embedder_batch_size or metadata["embedder_batch_size"],
- max_length=metadata["embedder_max_length"],
+ tokenizer_config=TokenizerConfig(max_length=metadata["embedder_max_length"]),
use_cache=embedder_use_cache or metadata["embedder_use_cache"],
)
)
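A small sanity sketch for the VectorIndex.load() path above: TokenizerConfig.max_length defaults to None, so metadata written with a null embedder_max_length should still validate (the metadata dict is illustrative):

from autointent.configs import EmbedderConfig, TokenizerConfig

metadata = {"embedder_max_length": None}  # hypothetical entry from an older index dump
config = EmbedderConfig(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    tokenizer_config=TokenizerConfig(max_length=metadata["embedder_max_length"]),
)
assert config.tokenizer_config.max_length is None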
autointent/configs/__init__.py (3 additions, 1 deletion)
@@ -2,14 +2,16 @@

from ._inference_node import InferenceNodeConfig
from ._optimization import DataConfig, LoggingConfig
- from ._transformers import CrossEncoderConfig, EmbedderConfig, TaskTypeEnum
+ from ._transformers import CrossEncoderConfig, EmbedderConfig, HFModelConfig, TaskTypeEnum, TokenizerConfig

__all__ = [
"CrossEncoderConfig",
"DataConfig",
"EmbedderConfig",
"HFModelConfig",
"InferenceNodeConfig",
"InferenceNodeConfig",
"LoggingConfig",
"TaskTypeEnum",
"TokenizerConfig",
]
autointent/configs/_transformers.py (14 additions, 9 deletions)
@@ -1,19 +1,24 @@
from enum import Enum
- from typing import Any
+ from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from typing_extensions import Self, assert_never


- class ModelConfig(BaseModel):
- model_config = ConfigDict(extra="forbid")
- batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
+ class TokenizerConfig(BaseModel):
+ padding: bool | Literal["longest", "max_length", "do_not_pad"] = True
+ truncation: bool = True
max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")


- class STModelConfig(ModelConfig):
- model_name: str
+ class HFModelConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid")
+ model_name: str = Field(
+ "prajjwal1/bert-tiny", description="Name of the hugging face repository with transformer model."
+ )
+ batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
+ tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)

@classmethod
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
@@ -26,7 +31,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
Model configuration.
"""
if values is None:
- return cls() # type: ignore[call-arg]
+ return cls()
if isinstance(values, BaseModel):
return values # type: ignore[return-value]
if isinstance(values, str):
@@ -45,7 +50,7 @@ class TaskTypeEnum(Enum):
sts = "sts"


- class EmbedderConfig(STModelConfig):
+ class EmbedderConfig(HFModelConfig):
model_name: str = Field("sentence-transformers/all-MiniLM-L6-v2", description="Name of the hugging face model.")
default_prompt: str | None = Field(
None, description="Default prompt for the model. This is used when no task specific prompt is not provided."
@@ -105,7 +110,7 @@ def get_prompt_type(self, prompt_type: TaskTypeEnum | None) -> str | None: # no
use_cache: bool = Field(False, description="Whether to use embeddings caching.")


- class CrossEncoderConfig(STModelConfig):
+ class CrossEncoderConfig(HFModelConfig):
model_name: str = Field("cross-encoder/ms-marco-MiniLM-L-6-v2", description="Name of the hugging face model.")
train_head: bool = Field(
False, description="Whether to train the head of the model. If False, LogReg will be trained."
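A short usage sketch for the reworked hierarchy, relying only on the defaults visible above (the concrete max_length value is illustrative):

from autointent.configs import EmbedderConfig, HFModelConfig, TokenizerConfig

# HFModelConfig now owns the shared transformer settings, including the nested tokenizer config.
base_cfg = HFModelConfig()  # defaults: prajjwal1/bert-tiny, batch_size=32, TokenizerConfig()
assert base_cfg.tokenizer_config.max_length is None

# The old flat field moves one level down: config.max_length -> config.tokenizer_config.max_length.
embedder_cfg = EmbedderConfig(tokenizer_config=TokenizerConfig(max_length=128))
assert embedder_cfg.model_name == "sentence-transformers/all-MiniLM-L6-v2"
assert embedder_cfg.tokenizer_config.max_length == 128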
autointent/modules/scoring/_bert.py (6 additions, 29 deletions)
@@ -16,25 +16,11 @@
)

from autointent import Context
- from autointent.configs import EmbedderConfig
+ from autointent.configs import HFModelConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer


- class TokenizerConfig:
- """Configuration for tokenizer parameters."""
-
- def __init__(
- self,
- max_length: int = 128,
- padding: str = "max_length",
- truncation: bool = True,
- ) -> None:
- self.max_length = max_length
- self.padding = padding
- self.truncation = truncation


class BertScorer(BaseScorer):
name = "transformer"
supports_multiclass = True
@@ -45,31 +31,28 @@ class BertScorer(BaseScorer):

def __init__(
self,
- model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+ model_config: HFModelConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 5e-5,
seed: int = 0,
- tokenizer_config: TokenizerConfig | None = None,
) -> None:
- self.model_config = EmbedderConfig.from_search_config(model_config)
+ self.model_config = HFModelConfig.from_search_config(model_config)
self.num_train_epochs = num_train_epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.seed = seed
- self.tokenizer_config = tokenizer_config or TokenizerConfig()
self._multilabel = False

@classmethod
def from_context(
cls,
context: Context,
- model_config: EmbedderConfig | str | None = None,
+ model_config: HFModelConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 5e-5,
seed: int = 0,
- tokenizer_config: TokenizerConfig | None = None,
) -> "BertScorer":
if model_config is None:
model_config = context.resolve_embedder()
@@ -79,7 +62,6 @@ def from_context(
batch_size=batch_size,
learning_rate=learning_rate,
seed=seed,
- tokenizer_config=tokenizer_config,
)

def get_embedder_config(self) -> dict[str, Any]:
@@ -114,10 +96,7 @@ def fit(

def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
return self._tokenizer( # type: ignore[no-any-return]
examples["text"],
padding=self.tokenizer_config.padding,
truncation=self.tokenizer_config.truncation,
max_length=self.tokenizer_config.max_length,
examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
)

dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -154,9 +133,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
msg = "Model is not trained. Call fit() first."
raise RuntimeError(msg)

- inputs = self._tokenizer(
- utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
- )
+ inputs = self._tokenizer(utterances, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())

with torch.no_grad():
outputs = self._model(**inputs)
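For reference, a sketch of the keyword arguments that the self.model_config.tokenizer_config.model_dump() spread above passes to the Hugging Face tokenizer, assuming the TokenizerConfig fields defined in _transformers.py (the max_length and padding values are illustrative):

from autointent.configs import TokenizerConfig

kwargs = TokenizerConfig(max_length=128, padding="max_length").model_dump()
# kwargs == {"padding": "max_length", "truncation": True, "max_length": 128}
# so the call in tokenize_function is equivalent to:
#     self._tokenizer(examples["text"], return_tensors="pt",
#                     padding="max_length", truncation=True, max_length=128)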
autointent/modules/scoring/_mlknn/mlknn.py (1 addition, 1 deletion)
@@ -140,7 +140,7 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
model_name=self.embedder_config.model_name,
device=self.embedder_config.device,
batch_size=self.embedder_config.batch_size,
- max_length=self.embedder_config.max_length,
+ tokenizer_config=self.embedder_config.tokenizer_config,
use_cache=self.embedder_config.use_cache,
),
)
autointent/modules/scoring/_sklearn/sklearn_scorer.py (1 addition, 1 deletion)
@@ -128,7 +128,7 @@ def fit(
model_name=self.embedder_config.model_name,
device=self.embedder_config.device,
batch_size=self.embedder_config.batch_size,
- max_length=self.embedder_config.max_length,
+ tokenizer_config=self.embedder_config.tokenizer_config,
use_cache=self.embedder_config.use_cache,
)
)