Skip to content

Commit c707ad3

Browse files
committed
Added test for PTuningScorer
1 parent 2919e2f commit c707ad3

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
import pytest
3+
4+
from autointent.context.data_handler import DataHandler
5+
from autointent.modules import PTuningScorer
6+
7+
8+
def test_ptuning_prediction(dataset):
    """Fit a PTuningScorer and sanity-check the scores it predicts.

    The checks cover the shape of the prediction matrix, the value range of
    the scores, per-row normalisation when the scorer is not multilabel, and
    the optional ``predict_with_metadata`` entry point.
    """
    handler = DataHandler(dataset)

    scorer_kwargs = {
        "base_model_config": "prajjwal1/bert-tiny",
        "num_train_epochs": 1,
        "batch_size": 8,
        "task_type": "SEQ_CLS",
        "num_virtual_tokens": 10,
    }
    scorer = PTuningScorer(**scorer_kwargs)

    train_texts = handler.train_utterances(0)
    train_targets = handler.train_labels(0)
    scorer.fit(train_texts, train_targets)

    test_data = [
        "why is there a hold on my american saving bank account",
        "i am nost sure why my account is blocked",
        "why is there a hold on my capital one checking account",
        "i think my account is blocked but i do not know the reason",
        "can you tell me why is my bank account frozen",
    ]
    expected_rows = len(test_data)

    predictions = scorer.predict(test_data)

    # One row per input utterance, one column per distinct training label.
    assert predictions.shape[0] == expected_rows
    distinct_labels = set(handler.train_labels(0))
    assert predictions.shape[1] == len(distinct_labels)

    # Every score must lie inside the closed interval [0, 1].
    lowest = np.min(predictions)
    highest = np.max(predictions)
    assert 0.0 <= lowest <= highest <= 1.0

    # Outside of multilabel mode each row is expected to sum to one.
    if not scorer._multilabel:
        for row in predictions:
            row_total = np.sum(row)
            np.testing.assert_almost_equal(row_total, 1.0, decimal=5)

    # The metadata-returning variant is only exercised when the scorer has it.
    if hasattr(scorer, "predict_with_metadata"):
        result = scorer.predict_with_metadata(test_data)
        predictions, metadata = result
        assert len(predictions) == expected_rows
        assert metadata is None
45+
46+
47+
def test_ptuning_cache_clearing(dataset):
    """Check that ``clear_cache`` leaves a fitted PTuningScorer unusable.

    After fitting and predicting once, the cache is cleared; the private
    ``_model`` and ``_tokenizer`` attributes must then be absent or ``None``,
    and a further ``predict`` call must raise ``RuntimeError``.
    """
    handler = DataHandler(dataset)

    scorer_kwargs = {
        "base_model_config": "prajjwal1/bert-tiny",
        "num_train_epochs": 1,
        "batch_size": 8,
        "task_type": "SEQ_CLS",
        "num_virtual_tokens": 20,
    }
    scorer = PTuningScorer(**scorer_kwargs)

    train_texts = handler.train_utterances(0)
    train_targets = handler.train_labels(0)
    scorer.fit(train_texts, train_targets)

    sample = ["test text"]
    scorer.predict(sample)
    scorer.clear_cache()

    # Each private handle must be either missing or reset to None.
    for attr_name in ("_model", "_tokenizer"):
        assert getattr(scorer, attr_name, None) is None

    # Predicting after the cache was cleared has to fail.
    with pytest.raises(RuntimeError):
        scorer.predict(sample)

0 commit comments

Comments
 (0)