
Commit 7dfb8f8

added test
1 parent 9235df9 commit 7dfb8f8

File tree (5 files changed: +74 additions, -2 deletions)

  autointent/modules/__init__.py
  autointent/modules/scoring/_lora/lora.py
  tests/assets/configs/multiclass.yaml
  tests/assets/configs/multilabel.yaml
  tests/modules/scoring/test_lora.py

autointent/modules/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
 from .scoring import (
+    BERTLoRAScorer,
     BertScorer,
     DescriptionScorer,
     DNNCScorer,
@@ -46,6 +47,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         SklearnScorer,
         MLKnnScorer,
         BertScorer,
+        BERTLoRAScorer
     ]
 )

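With this change the scorer is exported at package level and registered alongside the other scoring modules, so it can be imported directly, as the new test below does. A minimal usage sketch, mirroring the test's constructor arguments:

    from autointent.modules import BERTLoRAScorer

    scorer = BERTLoRAScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)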

autointent/modules/scoring/_lora/lora.py

Lines changed: 2 additions & 2 deletions
@@ -93,7 +93,7 @@ def fit(
         )
         self._model = get_peft_model(self._model, self._lora_config)

-        device = torch.device(self.model_config.device)
+        device = torch.device(self.model_config.device if self.model_config.device else 'cpu')
         self._model = self._model.to(device)

         use_cpu = self.model_config.device == "cpu"
@@ -137,7 +137,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)

-        device = torch.device(self.model_config.device)
+        device = torch.device(self.model_config.device if self.model_config.device else 'cpu')
         self._model = self._model.to(device)

         all_predictions = []
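
Both hunks apply the same guard in fit() and predict(): fall back to CPU when model_config.device is unset, presumably so an empty device setting is never handed to torch.device. Pulled out as a standalone helper for illustration (resolve_device is hypothetical, not part of the codebase):

    import torch

    def resolve_device(configured: str | None) -> torch.device:
        # Illustrative helper: mirrors the fallback added in lora.py and
        # defaults to CPU when no device is configured.
        return torch.device(configured if configured else "cpu")

    print(resolve_device(None))    # device(type='cpu')
    print(resolve_device("cuda"))  # device(type='cuda')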

tests/assets/configs/multiclass.yaml

Lines changed: 7 additions & 0 deletions
@@ -35,6 +35,13 @@
       batch_size: [8, 16]
       learning_rate: [5.0e-5]
       seed: [0]
+    - module_name: lora
+      model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8, 16]
+      learning_rate: [5.0e-5]
+      seed: [0]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:

tests/assets/configs/multilabel.yaml

Lines changed: 7 additions & 0 deletions
@@ -31,6 +31,13 @@
       batch_size: [8]
       learning_rate: [5.0e-5]
       seed: [0]
+    - module_name: lora
+      model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8]
+      learning_rate: [5.0e-5]
+      seed: [0]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:
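
Both configs gain the same lora entry in the scoring search space; only batch_size differs (two candidates in multiclass, one in multilabel). Assuming the optimizer enumerates every combination of the listed values (an assumption about autointent's search, not shown in this diff), the multiclass entry yields two candidate trials:

    from itertools import product

    # Candidate values copied from the multiclass.yaml entry above
    # (grid enumeration here is an assumption, used only to count combinations).
    grid = {
        "num_train_epochs": [1],
        "batch_size": [8, 16],
        "learning_rate": [5.0e-5],
        "seed": [0],
    }

    trials = [dict(zip(grid, combo)) for combo in product(*grid.values())]
    print(len(trials))  # 2 -- one trial per batch_size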

tests/modules/scoring/test_lora.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import numpy as np
+import pytest
+
+from autointent.context.data_handler import DataHandler
+from autointent.modules import BERTLoRAScorer
+
+
+def test_lora_prediction(dataset):
+    """Test that the transformer model can fit and make predictions."""
+    data_handler = DataHandler(dataset)
+
+    scorer = BERTLoRAScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
+
+    test_data = [
+        "why is there a hold on my american saving bank account",
+        "i am not sure why my account is blocked",
+        "why is there a hold on my capital one checking account",
+        "i think my account is blocked but i do not know the reason",
+        "can you tell me why is my bank account frozen",
+    ]
+
+    predictions = scorer.predict(test_data)
+
+    assert predictions.shape[0] == len(test_data)
+    assert predictions.shape[1] == len(set(data_handler.train_labels(0)))
+
+    assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0
+
+    if not scorer._multilabel:
+        for pred_row in predictions:
+            np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)
+
+    if hasattr(scorer, "predict_with_metadata"):
+        predictions, metadata = scorer.predict_with_metadata(test_data)
+        assert len(predictions) == len(test_data)
+        assert metadata is None
+
+
+def test_bert_cache_clearing(dataset):
+    """Test that the transformer model properly handles cache clearing."""
+    data_handler = DataHandler(dataset)
+
+    scorer = BERTLoRAScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
+
+    test_data = ["test text"]
+
+    scorer.predict(test_data)
+    scorer.clear_cache()
+
+    assert not hasattr(scorer, "_model") or scorer._model is None
+    assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None
+
+    with pytest.raises(RuntimeError):
+        scorer.predict(test_data)
