Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion autointent/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
)
from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
from .regex import SimpleRegex
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
from .scoring import (
BertScorer,
DescriptionScorer,
DNNCScorer,
KNNScorer,
LinearScorer,
MLKnnScorer,
RerankScorer,
SklearnScorer,
)

T = TypeVar("T", bound=BaseModule)

Expand All @@ -36,6 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
RerankScorer,
SklearnScorer,
MLKnnScorer,
BertScorer,
]
)

Expand Down
2 changes: 2 additions & 0 deletions autointent/modules/scoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._bert import BertScorer
from ._description import DescriptionScorer
from ._dnnc import DNNCScorer
from ._knn import KNNScorer, RerankScorer
Expand All @@ -6,6 +7,7 @@
from ._sklearn import SklearnScorer

__all__ = [
"BertScorer",
"DNNCScorer",
"DescriptionScorer",
"KNNScorer",
Expand Down
173 changes: 173 additions & 0 deletions autointent/modules/scoring/_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""BertScorer class for transformer-based classification."""

import tempfile
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
from datasets import Dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)

from autointent import Context
from autointent.configs import EmbedderConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer


class TokenizerConfig:
"""Configuration for tokenizer parameters."""

def __init__(
self,
max_length: int = 128,
padding: str = "max_length",
truncation: bool = True,
) -> None:
self.max_length = max_length
self.padding = padding
self.truncation = truncation


class BertScorer(BaseScorer):
name = "transformer"
supports_multiclass = True
supports_multilabel = True
_multilabel: bool
_model: Any
_tokenizer: Any

def __init__(
self,
model_config: EmbedderConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 5e-5,
seed: int = 0,
tokenizer_config: TokenizerConfig | None = None,
) -> None:
self.model_config = EmbedderConfig.from_search_config(model_config)
self.num_train_epochs = num_train_epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.seed = seed
self.tokenizer_config = tokenizer_config or TokenizerConfig()
self._multilabel = False

@classmethod
def from_context(
cls,
context: Context,
model_config: EmbedderConfig | str | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 5e-5,
seed: int = 0,
tokenizer_config: TokenizerConfig | None = None,
) -> "BertScorer":
if model_config is None:
model_config = context.resolve_embedder()
return cls(
model_config=model_config,
num_train_epochs=num_train_epochs,
batch_size=batch_size,
learning_rate=learning_rate,
seed=seed,
tokenizer_config=tokenizer_config,
)

def get_embedder_config(self) -> dict[str, Any]:
return self.model_config.model_dump()

def _validate_task(self, labels: ListOfLabels) -> None:
"""Validate the task and set _multilabel flag."""
super()._validate_task(labels)
self._multilabel = isinstance(labels[0], list)

def fit(
self,
utterances: list[str],
labels: ListOfLabels,
) -> None:
if hasattr(self, "_model"):
self.clear_cache()

self._validate_task(labels)

if self._multilabel:
labels_array = np.array(labels)
num_labels = labels_array.shape[1]
else:
num_labels = len(set(labels))

model_name = self.model_config.model_name
self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"

def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
return self._tokenizer( # type: ignore[no-any-return]
examples["text"],
padding=self.tokenizer_config.padding,
truncation=self.tokenizer_config.truncation,
max_length=self.tokenizer_config.max_length,
)

dataset = Dataset.from_dict({"text": utterances, "labels": labels})
tokenized_dataset = dataset.map(tokenize_function, batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
num_train_epochs=self.num_train_epochs,
per_device_train_batch_size=self.batch_size,
learning_rate=self.learning_rate,
seed=self.seed,
save_strategy="no",
logging_strategy="steps",
logging_steps=10,
report_to="wandb",
use_cpu=use_cpu,
)

trainer = Trainer(
model=self._model,
args=training_args,
train_dataset=tokenized_dataset,
tokenizer=self._tokenizer,
data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
)

trainer.train()

self._model.eval()

def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
msg = "Model is not trained. Call fit() first."
raise RuntimeError(msg)

inputs = self._tokenizer(
utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
)

with torch.no_grad():
outputs = self._model(**inputs)
logits = outputs.logits

if self._multilabel:
return torch.sigmoid(logits).numpy()
return torch.softmax(logits, dim=1).numpy()

def clear_cache(self) -> None:
if hasattr(self, "_model"):
del self._model
if hasattr(self, "_tokenizer"):
del self._tokenizer
67 changes: 67 additions & 0 deletions tests/modules/scoring/test_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from autointent.context.data_handler import DataHandler
from autointent.modules import BertScorer


def test_bert_prediction(dataset):
"""Test that the transformer model can fit and make predictions."""
data_handler = DataHandler(dataset)

scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

test_data = [
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]

predictions = scorer.predict(test_data)

# Verify prediction shape
assert predictions.shape[0] == len(test_data)
assert predictions.shape[1] == len(set(data_handler.train_labels(0)))

# Verify predictions are probabilities
assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0

# Verify probabilities sum to 1 for multiclass
if not scorer._multilabel:
for pred_row in predictions:
np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)

# Test metadata function if available
if hasattr(scorer, "predict_with_metadata"):
predictions, metadata = scorer.predict_with_metadata(test_data)
assert len(predictions) == len(test_data)
assert metadata is None


def test_bert_cache_clearing(dataset):
"""Test that the transformer model properly handles cache clearing."""
data_handler = DataHandler(dataset)

scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

test_data = ["test text"]

# Should work before clearing cache
scorer.predict(test_data)

# Clear the cache
scorer.clear_cache()

# Verify model and tokenizer are removed
assert not hasattr(scorer, "_model") or scorer._model is None
assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None

# Should raise exception after clearing cache
with pytest.raises(RuntimeError):
scorer.predict(test_data)
Loading