12 changes: 11 additions & 1 deletion autointent/modules/__init__.py
@@ -12,7 +12,16 @@
)
from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
from .regex import SimpleRegex
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
from .scoring import (
DescriptionScorer,
DNNCScorer,
KNNScorer,
LinearScorer,
MLKnnScorer,
RerankScorer,
SklearnScorer,
TransformerScorer,
)

T = TypeVar("T", bound=BaseModule)

@@ -36,6 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
RerankScorer,
SklearnScorer,
MLKnnScorer,
TransformerScorer,
]
)

2 changes: 2 additions & 0 deletions autointent/modules/scoring/__init__.py
@@ -4,6 +4,7 @@
from ._linear import LinearScorer
from ._mlknn import MLKnnScorer
from ._sklearn import SklearnScorer
from ._transformer import TransformerScorer

__all__ = [
"DNNCScorer",
@@ -13,4 +14,5 @@
"MLKnnScorer",
"RerankScorer",
"SklearnScorer",
"TransformerScorer"
]
129 changes: 129 additions & 0 deletions autointent/modules/scoring/_transformer.py
@@ -0,0 +1,129 @@
"""TransformerScorer class for transformer-based classification."""

import tempfile
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
from datasets import Dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)

from autointent import Context
from autointent.configs import EmbedderConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer


class TransformerScorer(BaseScorer):
name = "transformer"
supports_multiclass = True
supports_multilabel = True
_multilabel: bool
_model: Any
_tokenizer: Any

def __init__(
self,
model_config: EmbedderConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 5e-5,
seed: int = 0,
) -> None:
self.model_config = EmbedderConfig.from_search_config(model_config)
self.num_train_epochs = num_train_epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.seed = seed

@classmethod
def from_context(
cls,
context: Context,
model_config: EmbedderConfig | str | None = None,
) -> "TransformerScorer":
if model_config is None:
model_config = context.resolve_embedder()
return cls(model_config=model_config)

def get_embedder_config(self) -> dict[str, Any]:
return self.model_config.model_dump()

def fit(
self,
utterances: list[str],
labels: ListOfLabels,
) -> None:
if hasattr(self, "_model"):
self.clear_cache()

self._validate_task(labels)

if self._multilabel:
labels_array = np.array(labels) if not isinstance(labels, np.ndarray) else labels
num_labels = labels_array.shape[1]
else:
num_labels = len(set(labels))

model_name = self.model_config.model_name
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

Member: The model isn't moved to a device (I think it can be passed to the Trainer).

def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
return self._tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

dataset = Dataset.from_dict({"text": utterances, "labels": labels})
tokenized_dataset = dataset.map(tokenize_function, batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
num_train_epochs=self.num_train_epochs,
per_device_train_batch_size=self.batch_size,
learning_rate=self.learning_rate,
seed=self.seed,
save_strategy="no",
logging_strategy="no",
report_to="none",
)

trainer = Trainer(
model=self._model,
args=training_args,
train_dataset=tokenized_dataset,
tokenizer=self._tokenizer,
data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
)

trainer.train()
Collaborator: As far as I understand, this should be Trainer(model=self._model.to(device), ...) so that training actually runs on the GPU. I might be wrong; let me know if it ended up on the GPU on its own in your runs.

Collaborator: Ideally this parameter should simply be inferred from EmbedderConfig.
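
A minimal sketch of what the suggested device handling could look like, using the names from fit() and predict() above. The device attribute on EmbedderConfig is a hypothetical field here, not something this PR defines; the fallback simply picks CUDA when it is available:

    # Sketch only: a "device" field on EmbedderConfig is an assumption, not confirmed by this PR.
    device = getattr(self.model_config, "device", None) or ("cuda" if torch.cuda.is_available() else "cpu")
    self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)

    # In predict(), the tokenized tensors would then need to sit on the same device before the forward pass:
    inputs = {key: value.to(self._model.device) for key, value in inputs.items()}

Worth noting that the Hugging Face Trainer normally moves the model onto the device from TrainingArguments by itself during training, so the explicit .to(device) mainly matters for the inference path in predict(), where the logits would also need a .cpu() before calling .numpy().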

self._model.eval()

def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
msg = "Model is not trained. Call fit() first."
raise RuntimeError(msg)

inputs = self._tokenizer(utterances, padding=True, truncation=True, max_length=128, return_tensors="pt")

with torch.no_grad():
outputs = self._model(**inputs)
logits = outputs.logits

if self._multilabel:
return torch.sigmoid(logits).numpy()
return torch.softmax(logits, dim=1).numpy()


def clear_cache(self) -> None:
if hasattr(self, "_model"):
del self._model
if hasattr(self, "_tokenizer"):
del self._tokenizer
46 changes: 46 additions & 0 deletions tests/modules/scoring/test_transformer.py
@@ -0,0 +1,46 @@
import numpy as np
import pytest

from autointent.context.data_handler import DataHandler
from autointent.modules import TransformerScorer


def test_base_transformer(dataset):
data_handler = DataHandler(dataset)

scorer = TransformerScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

test_data = [
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]

predictions = scorer.predict(test_data)

assert predictions.shape[0] == len(test_data)
assert predictions.shape[1] == len(set(data_handler.train_labels(0)))

assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0

if not scorer._multilabel:
for pred_row in predictions:
np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)

if hasattr(scorer, "predict_with_metadata"):
predictions, metadata = scorer.predict_with_metadata(test_data)
assert len(predictions) == len(test_data)
assert metadata is None
else:
pytest.skip("predict_with_metadata not implemented in TransformerScorer")

scorer.clear_cache()
assert not hasattr(scorer, "_model") or scorer._model is None
assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None

with pytest.raises(RuntimeError):
scorer.predict(test_data)