Skip to content

Commit 7a5d556

Browse files
committed
added dump check
1 parent 988f7d5 commit 7a5d556

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

tests/modules/scoring/test_lora.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from autointent.context.data_handler import DataHandler
5+
from autointent._dump_tools import Dumper
56
from autointent.modules import BERTLoRAScorer
67

78

@@ -20,15 +21,30 @@ def test_lora_prediction(dataset):
2021
"can you tell me why is my bank account frozen",
2122
]
2223

23-
predictions = scorer.predict(test_data)
24+
# Get initial predictions
25+
initial_predictions = scorer.predict(test_data)
2426

25-
assert predictions.shape[0] == len(test_data)
26-
assert predictions.shape[1] == len(set(data_handler.train_labels(0)))
27+
# Perform dump and load
28+
import tempfile
29+
with tempfile.TemporaryDirectory() as tmpdir:
30+
import pathlib
31+
dump_path = pathlib.Path(tmpdir)
32+
Dumper.dump(scorer, dump_path)
33+
34+
# Create new scorer instance and load state
35+
new_scorer = BERTLoRAScorer(transformer_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
36+
Dumper.load(new_scorer, dump_path)
37+
38+
loaded_predictions = new_scorer.predict(test_data)
39+
np.testing.assert_array_almost_equal(initial_predictions, loaded_predictions, decimal=5)
2740

28-
assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0
41+
assert initial_predictions.shape[0] == len(test_data)
42+
assert initial_predictions.shape[1] == len(set(data_handler.train_labels(0)))
43+
44+
assert 0.0 <= np.min(initial_predictions) <= np.max(initial_predictions) <= 1.0
2945

3046
if not scorer._multilabel:
31-
for pred_row in predictions:
47+
for pred_row in initial_predictions:
3248
np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)
3349

3450
if hasattr(scorer, "predict_with_metadata"):

0 commit comments

Comments
 (0)