Commit 75d9b31

Added code for full tuning
1 parent: a35046b

4 files changed: +188 −1 lines changed

autointent/modules/__init__.py

Lines changed: 11 additions & 1 deletion

@@ -12,7 +12,16 @@
 )
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
-from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
+from .scoring import (
+    DescriptionScorer,
+    DNNCScorer,
+    KNNScorer,
+    LinearScorer,
+    MLKnnScorer,
+    RerankScorer,
+    SklearnScorer,
+    TransformerScorer,
+)

 T = TypeVar("T", bound=BaseModule)

@@ -36,6 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         RerankScorer,
         SklearnScorer,
         MLKnnScorer,
+        TransformerScorer,
     ]
 )

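With this change, TransformerScorer is exported from the package root alongside the other scorers. Assuming _create_modules_dict keys each module by its name attribute (the new class sets name = "transformer"), the scorer also becomes resolvable through the modules registry. A minimal import sketch:

    from autointent.modules import TransformerScorer

    # defaults taken from the new class: 3 epochs, batch size 8, learning rate 5e-5
    scorer = TransformerScorer()
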
autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 from ._linear import LinearScorer
 from ._mlknn import MLKnnScorer
 from ._sklearn import SklearnScorer
+from ._transformer import TransformerScorer

 __all__ = [
     "DNNCScorer",
@@ -13,4 +14,5 @@
     "MLKnnScorer",
     "RerankScorer",
     "SklearnScorer",
+    "TransformerScorer",
 ]
autointent/modules/scoring/_transformer.py

Lines changed: 129 additions & 0 deletions (new file)

"""TransformerScorer class for transformer-based classification."""

import tempfile
from typing import Any

import numpy as np
import numpy.typing as npt
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from autointent import Context
from autointent.configs import EmbedderConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer


class TransformerScorer(BaseScorer):
    """Scorer that fully fine-tunes a transformer with a classification head on the training utterances."""

    name = "transformer"
    supports_multiclass = True
    supports_multilabel = True
    _multilabel: bool
    _model: Any
    _tokenizer: Any

    def __init__(
        self,
        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
        num_train_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
    ) -> None:
        self.model_config = EmbedderConfig.from_search_config(model_config)
        self.num_train_epochs = num_train_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed = seed

    @classmethod
    def from_context(
        cls,
        context: Context,
        model_config: EmbedderConfig | str | None = None,
    ) -> "TransformerScorer":
        if model_config is None:
            model_config = context.resolve_embedder()
        return cls(model_config=model_config)

    def get_embedder_config(self) -> dict[str, Any]:
        return self.model_config.model_dump()

    def fit(
        self,
        utterances: list[str],
        labels: ListOfLabels,
    ) -> None:
        # drop any previously fine-tuned model before retraining
        if hasattr(self, "_model"):
            self.clear_cache()

        self._validate_task(labels)

        if self._multilabel:
            labels_array = labels if isinstance(labels, np.ndarray) else np.array(labels)
            num_labels = labels_array.shape[1]
            # BCEWithLogitsLoss expects float multi-hot targets
            labels = labels_array.astype(float).tolist()
        else:
            num_labels = len(set(labels))

        model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            # declare the problem type explicitly so the correct loss is selected
            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
        )

        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            # sequences are padded to a fixed length here, so the collator has little left to pad
            return self._tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # checkpoints go to a throwaway directory and are discarded after training
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                num_train_epochs=self.num_train_epochs,
                per_device_train_batch_size=self.batch_size,
                learning_rate=self.learning_rate,
                seed=self.seed,
                save_strategy="no",
                logging_strategy="no",
                report_to="none",
            )

            trainer = Trainer(
                model=self._model,
                args=training_args,
                train_dataset=tokenized_dataset,
                tokenizer=self._tokenizer,
                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
            )

            trainer.train()

        self._model.eval()

    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
        if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
            msg = "Model is not trained. Call fit() first."
            raise RuntimeError(msg)

        inputs = self._tokenizer(utterances, padding=True, truncation=True, max_length=128, return_tensors="pt")
        # the Trainer may have moved the model to an accelerator, so send inputs to the same device
        inputs = inputs.to(self._model.device)

        with torch.no_grad():
            outputs = self._model(**inputs)
            logits = outputs.logits

        if self._multilabel:
            # independent per-class probabilities
            return torch.sigmoid(logits).cpu().numpy()
        # a single softmax distribution over classes
        return torch.softmax(logits, dim=1).cpu().numpy()

    def clear_cache(self) -> None:
        if hasattr(self, "_model"):
            del self._model
        if hasattr(self, "_tokenizer"):
            del self._tokenizer
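
Taken together, the module follows the library's scorer contract: construct, fit on utterances and labels, then predict class probabilities. A minimal end-to-end sketch, assuming a small checkpoint such as prajjwal1/bert-tiny (the one used in the test below) and a flat multiclass label list; the utterances and intents here are illustrative only:

    from autointent.modules import TransformerScorer

    utterances = [
        "my card was declined",
        "i lost my card",
        "what is my balance",
        "how much money do i have",
    ]
    labels = [0, 0, 1, 1]  # intent 0: card issues, intent 1: balance queries

    scorer = TransformerScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=2)
    scorer.fit(utterances, labels)

    probs = scorer.predict(["why was my payment rejected"])  # shape (1, 2); each row sums to 1
    scorer.clear_cache()  # frees the fine-tuned model and tokenizer
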
Lines changed: 46 additions & 0 deletions (new file)

import numpy as np
import pytest

from autointent.context.data_handler import DataHandler
from autointent.modules import TransformerScorer


def test_base_transformer(dataset):
    data_handler = DataHandler(dataset)

    scorer = TransformerScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

    test_data = [
        "why is there a hold on my american saving bank account",
        "i am not sure why my account is blocked",
        "why is there a hold on my capital one checking account",
        "i think my account is blocked but i do not know the reason",
        "can you tell me why is my bank account frozen",
    ]

    predictions = scorer.predict(test_data)

    # one probability row per utterance, one column per training class
    assert predictions.shape[0] == len(test_data)
    assert predictions.shape[1] == len(set(data_handler.train_labels(0)))

    # all scores are valid probabilities
    assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0

    # in the multiclass case each row is a softmax distribution
    if not scorer._multilabel:
        for pred_row in predictions:
            np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)

    if hasattr(scorer, "predict_with_metadata"):
        predictions, metadata = scorer.predict_with_metadata(test_data)
        assert len(predictions) == len(test_data)
        assert metadata is None
    else:
        pytest.skip("predict_with_metadata not implemented in TransformerScorer")

    # clearing the cache drops the model, so prediction must fail afterwards
    scorer.clear_cache()
    assert not hasattr(scorer, "_model") or scorer._model is None
    assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None

    with pytest.raises(RuntimeError):
        scorer.predict(test_data)
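
The test exercises only the multiclass path. Because the class also declares supports_multilabel = True, the sigmoid branch of predict can be reached with multi-hot labels; the following sketch is hypothetical and assumes the base class's _validate_task accepts rows of 0/1 labels:

    # hypothetical multilabel usage with illustrative data
    utterances = ["block my card and check my balance", "what is my balance"]
    labels = [[1, 1], [0, 1]]  # multi-hot rows over two intents

    scorer = TransformerScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1)
    scorer.fit(utterances, labels)
    probs = scorer.predict(["freeze my card"])  # sigmoid scores in [0, 1]; rows need not sum to 1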
