Skip to content

Commit fd607ac

Browse files
committed
added code for test dumper
1 parent 26f2c45 commit fd607ac

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

tests/modules/scoring/test_bert.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,59 @@
1+
import shutil
2+
import tempfile
3+
from pathlib import Path
4+
15
import numpy as np
26
import pytest
37

48
from autointent.context.data_handler import DataHandler
59
from autointent.modules import BertScorer
610

711

12+
def test_bert_scorer_dump_load(dataset):
13+
"""Test that BertScorer can be saved and loaded while preserving predictions."""
14+
data_handler = DataHandler(dataset)
15+
16+
# Create and train scorer
17+
scorer_original = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
18+
scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
19+
20+
# Test data
21+
test_data = [
22+
"why is there a hold on my account",
23+
"why is my bank account frozen",
24+
]
25+
26+
# Get predictions before saving
27+
predictions_before = scorer_original.predict(test_data)
28+
29+
# Create temp directory and save model
30+
temp_dir_path = Path(tempfile.mkdtemp(prefix="bert_scorer_test_"))
31+
try:
32+
# Save the model
33+
scorer_original.dump(str(temp_dir_path))
34+
35+
# Create a new scorer and load saved model
36+
scorer_loaded = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
37+
scorer_loaded.load(str(temp_dir_path))
38+
39+
# Verify model and tokenizer are loaded
40+
assert hasattr(scorer_loaded, "_model")
41+
assert scorer_loaded._model is not None
42+
assert hasattr(scorer_loaded, "_tokenizer")
43+
assert scorer_loaded._tokenizer is not None
44+
45+
# Get predictions after loading
46+
predictions_after = scorer_loaded.predict(test_data)
47+
48+
# Verify predictions match
49+
assert predictions_before.shape == predictions_after.shape
50+
np.testing.assert_allclose(predictions_before, predictions_after, atol=1e-6)
51+
52+
finally:
53+
# Clean up
54+
shutil.rmtree(temp_dir_path)
55+
56+
857
def test_bert_prediction(dataset):
958
"""Test that the transformer model can fit and make predictions."""
1059
data_handler = DataHandler(dataset)

0 commit comments

Comments
 (0)