Skip to content

Commit 8abba9f

Browse files
committed
Update test_lora.py
1 parent e0761f5 commit 8abba9f

File tree

1 file changed

+69
-22
lines changed

1 file changed

+69
-22
lines changed

tests/modules/scoring/test_lora.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,65 @@
1+
import shutil
2+
import tempfile
3+
from pathlib import Path
4+
15
import numpy as np
26
import pytest
37

4-
from autointent._dump_tools import Dumper
58
from autointent.context.data_handler import DataHandler
69
from autointent.modules import BERTLoRAScorer
710

811

12+
def test_lora_scorer_dump_load(dataset):
    """Round-trip a trained BERTLoRAScorer through dump/load and check predictions survive."""
    handler = DataHandler(dataset)

    # Fit a scorer on the first training split.
    source_scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
    source_scorer.fit(handler.train_utterances(0), handler.train_labels(0))

    queries = [
        "why is there a hold on my account",
        "why is my bank account frozen",
    ]

    # Reference predictions taken before any serialization happens.
    before = source_scorer.predict(queries)

    # mkdtemp + manual rmtree(ignore_errors=True) is used instead of
    # TemporaryDirectory as a workaround for Windows permission errors.
    dump_dir = Path(tempfile.mkdtemp(prefix="lora_scorer_test_"))
    try:
        source_scorer.dump(str(dump_dir))

        # Restore into a brand-new scorer instance.
        restored = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
        restored.load(str(dump_dir))

        # Both the model and the tokenizer must come back non-None.
        assert hasattr(restored, "_model")
        assert restored._model is not None
        assert hasattr(restored, "_tokenizer")
        assert restored._tokenizer is not None

        after = restored.predict(queries)

        # Loaded scorer must reproduce the original predictions.
        assert before.shape == after.shape
        np.testing.assert_allclose(before, after, atol=1e-6)

    finally:
        shutil.rmtree(dump_dir, ignore_errors=True)  # workaround for windows permission error
957
def test_lora_prediction(dataset):
10-
"""Test that the transformer model can fit and make predictions."""
58+
"""Test that the lora model can fit and make predictions."""
1159
data_handler = DataHandler(dataset)
1260

1361
scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
62+
1463
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
1564

1665
test_data = [
@@ -21,49 +70,47 @@ def test_lora_prediction(dataset):
2170
"can you tell me why is my bank account frozen",
2271
]
2372

24-
initial_predictions = scorer.predict(test_data)
73+
predictions = scorer.predict(test_data)
2574

26-
import tempfile
27-
with tempfile.TemporaryDirectory() as tmpdir:
28-
import pathlib
29-
dump_path = pathlib.Path(tmpdir)
30-
Dumper.dump(scorer, dump_path)
75+
# Verify prediction shape
76+
assert predictions.shape[0] == len(test_data)
77+
assert predictions.shape[1] == len(set(data_handler.train_labels(0)))
3178

32-
new_scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
33-
Dumper.load(new_scorer, dump_path)
34-
35-
loaded_predictions = new_scorer.predict(test_data)
36-
np.testing.assert_array_almost_equal(initial_predictions, loaded_predictions, decimal=5)
37-
38-
assert initial_predictions.shape[0] == len(test_data)
39-
assert initial_predictions.shape[1] == len(set(data_handler.train_labels(0)))
40-
41-
assert 0.0 <= np.min(initial_predictions) <= np.max(initial_predictions) <= 1.0
79+
# Verify predictions are probabilities
80+
assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0
4281

82+
# Verify probabilities sum to 1 for multiclass
4383
if not scorer._multilabel:
44-
for pred_row in initial_predictions:
84+
for pred_row in predictions:
4585
np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)
4686

87+
# Test metadata function if available
4788
if hasattr(scorer, "predict_with_metadata"):
4889
predictions, metadata = scorer.predict_with_metadata(test_data)
4990
assert len(predictions) == len(test_data)
5091
assert metadata is None
5192

5293

def test_lora_cache_clearing(dataset):
    """Test that clearing the cache drops model state and makes prediction fail."""
    handler = DataHandler(dataset)

    scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

    scorer.fit(handler.train_utterances(0), handler.train_labels(0))

    sample = ["test text"]

    # Prediction works while the model is still cached.
    scorer.predict(sample)

    scorer.clear_cache()

    # After clearing, model and tokenizer must be absent or None.
    assert getattr(scorer, "_model", None) is None
    assert getattr(scorer, "_tokenizer", None) is None

    # Further prediction attempts must fail loudly.
    with pytest.raises(RuntimeError):
        scorer.predict(sample)

0 commit comments

Comments
 (0)