Commit 5890d45

work on review

1 parent 75d9b31 commit 5890d45

File tree

4 files changed: +85 −20 lines

autointent/modules/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -13,14 +13,14 @@
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
 from .scoring import (
+    BertScorer,
     DescriptionScorer,
     DNNCScorer,
     KNNScorer,
     LinearScorer,
     MLKnnScorer,
     RerankScorer,
     SklearnScorer,
-    TransformerScorer,
 )

 T = TypeVar("T", bound=BaseModule)

@@ -45,7 +45,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         RerankScorer,
         SklearnScorer,
         MLKnnScorer,
-        TransformerScorer,
+        BertScorer,
     ]
 )

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -4,15 +4,15 @@
 from ._linear import LinearScorer
 from ._mlknn import MLKnnScorer
 from ._sklearn import SklearnScorer
-from ._transformer import TransformerScorer
+from ._transformer import BertScorer

 __all__ = [
+    "BertScorer",
     "DNNCScorer",
     "DescriptionScorer",
     "KNNScorer",
     "LinearScorer",
     "MLKnnScorer",
     "RerankScorer",
     "SklearnScorer",
-    "TransformerScorer"
 ]
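
With both export lists updated, downstream code imports the scorer under its new name. A minimal sketch of the new import path (old code that imported TransformerScorer will now raise ImportError and must be updated):

    # The class is re-exported at the package level, so either path works:
    from autointent.modules import BertScorer
    # or: from autointent.modules.scoring import BertScorer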

autointent/modules/scoring/_transformer.py

Lines changed: 55 additions & 11 deletions

@@ -1,4 +1,4 @@
-"""TransformerScorer class for transformer-based classification."""
+"""BertScorer class for transformer-based classification."""

 import tempfile
 from typing import Any
@@ -21,7 +21,21 @@
 from autointent.modules.base import BaseScorer


-class TransformerScorer(BaseScorer):
+class TokenizerConfig:
+    """Configuration for tokenizer parameters."""
+
+    def __init__(
+        self,
+        max_length: int = 128,
+        padding: str = "max_length",
+        truncation: bool = True,
+    ) -> None:
+        self.max_length = max_length
+        self.padding = padding
+        self.truncation = truncation
+
+
+class BertScorer(BaseScorer):
     name = "transformer"
     supports_multiclass = True
     supports_multilabel = True
@@ -36,26 +50,46 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
+        tokenizer_config: TokenizerConfig | None = None,
     ) -> None:
         self.model_config = EmbedderConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
+        self.tokenizer_config = tokenizer_config or TokenizerConfig()
+        self._multilabel = False

     @classmethod
     def from_context(
         cls,
         context: Context,
         model_config: EmbedderConfig | str | None = None,
-    ) -> "TransformerScorer":
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        tokenizer_config: TokenizerConfig | None = None,
+    ) -> "BertScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
-        return cls(model_config=model_config)
+        return cls(
+            model_config=model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            tokenizer_config=tokenizer_config,
+        )

     def get_embedder_config(self) -> dict[str, Any]:
         return self.model_config.model_dump()

+    def _validate_task(self, labels: ListOfLabels) -> None:
+        """Validate the task and set _multilabel flag."""
+        super()._validate_task(labels)
+        self._multilabel = isinstance(labels[0], list)
+
     def fit(
         self,
         utterances: list[str],
@@ -67,7 +101,7 @@ def fit(
         self._validate_task(labels)

         if self._multilabel:
-            labels_array = np.array(labels) if not isinstance(labels, np.ndarray) else labels
+            labels_array = np.array(labels)
             num_labels = labels_array.shape[1]
         else:
             num_labels = len(set(labels))
@@ -76,8 +110,15 @@
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

-        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
+        use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
+
+        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:  # type: ignore[no-any-return]
+            return self._tokenizer(
+                examples["text"],
+                padding=self.tokenizer_config.padding,
+                truncation=self.tokenizer_config.truncation,
+                max_length=self.tokenizer_config.max_length,
+            )

         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
         tokenized_dataset = dataset.map(tokenize_function, batched=True)
@@ -90,8 +131,10 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             learning_rate=self.learning_rate,
             seed=self.seed,
             save_strategy="no",
-            logging_strategy="no",
-            report_to="none",
+            logging_strategy="steps",
+            logging_steps=10,
+            report_to="wandb",
+            use_cpu=use_cpu,
         )

         trainer = Trainer(
@@ -111,7 +154,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)

-        inputs = self._tokenizer(utterances, padding=True, truncation=True, max_length=128, return_tensors="pt")
+        inputs = self._tokenizer(
+            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
+        )

         with torch.no_grad():
             outputs = self._model(**inputs)
@@ -121,7 +166,6 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             return torch.sigmoid(logits).numpy()
         return torch.softmax(logits, dim=1).numpy()

-
     def clear_cache(self) -> None:
         if hasattr(self, "_model"):
             del self._model
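
Taken together, the new TokenizerConfig and the widened constructor and from_context signatures let callers tune tokenization without touching the training loop. A minimal usage sketch based only on the signatures in this diff (the prajjwal1/bert-tiny checkpoint is borrowed from the test file below; the utterances and labels are illustrative, and TokenizerConfig is imported from the private module since it is not re-exported):

    from autointent.modules import BertScorer
    from autointent.modules.scoring._transformer import TokenizerConfig

    # Cap sequences at 64 tokens; padding and truncation keep their defaults.
    tok_cfg = TokenizerConfig(max_length=64)

    scorer = BertScorer(
        model_config="prajjwal1/bert-tiny",  # tiny checkpoint, as in the tests
        num_train_epochs=1,
        batch_size=8,
        tokenizer_config=tok_cfg,
    )
    scorer.fit(["turn on the lights", "play some jazz"], [0, 1])  # multiclass: flat int labels
    probs = scorer.predict(["dim the lights"])  # softmax probabilities, shape (1, 2)

Note that fit() now logs to Weights & Biases every 10 steps by default (report_to="wandb"); running without a wandb account presumably requires disabling it, e.g. via the WANDB_MODE=disabled environment variable.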

tests/modules/scoring/test_transformer.py

Lines changed: 26 additions & 5 deletions

@@ -2,13 +2,14 @@
 import pytest

 from autointent.context.data_handler import DataHandler
-from autointent.modules import TransformerScorer
+from autointent.modules import BertScorer


-def test_base_transformer(dataset):
+def test_bert_prediction(dataset):
+    """Test that the transformer model can fit and make predictions."""
     data_handler = DataHandler(dataset)

-    scorer = TransformerScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

@@ -22,25 +23,45 @@ def test_base_transformer(dataset):

     predictions = scorer.predict(test_data)

+    # Verify prediction shape
     assert predictions.shape[0] == len(test_data)
     assert predictions.shape[1] == len(set(data_handler.train_labels(0)))

+    # Verify predictions are probabilities
     assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0

+    # Verify probabilities sum to 1 for multiclass
     if not scorer._multilabel:
         for pred_row in predictions:
             np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)

+    # Test metadata function if available
     if hasattr(scorer, "predict_with_metadata"):
         predictions, metadata = scorer.predict_with_metadata(test_data)
         assert len(predictions) == len(test_data)
         assert metadata is None
-    else:
-        pytest.skip("predict_with_metadata not implemented in TransformerScorer")

+
+def test_bert_cache_clearing(dataset):
+    """Test that the transformer model properly handles cache clearing."""
+    data_handler = DataHandler(dataset)
+
+    scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+
+    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
+
+    test_data = ["test text"]
+
+    # Should work before clearing cache
+    scorer.predict(test_data)
+
+    # Clear the cache
     scorer.clear_cache()
+
+    # Verify model and tokenizer are removed
     assert not hasattr(scorer, "_model") or scorer._model is None
     assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None

+    # Should raise exception after clearing cache
     with pytest.raises(RuntimeError):
         scorer.predict(test_data)
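
Assuming the repository's usual test setup (the dataset fixture comes from the suite's conftest), the renamed tests can presumably be run in isolation with:

    pytest tests/modules/scoring/test_transformer.py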

0 commit comments