22import pytest
33
44from autointent .context .data_handler import DataHandler
5+ from autointent ._dump_tools import Dumper
56from autointent .modules import BERTLoRAScorer
67
78
@@ -20,15 +21,30 @@ def test_lora_prediction(dataset):
2021 "can you tell me why is my bank account frozen" ,
2122 ]
2223
23- predictions = scorer .predict (test_data )
24+ # Get initial predictions
25+ initial_predictions = scorer .predict (test_data )
2426
25- assert predictions .shape [0 ] == len (test_data )
26- assert predictions .shape [1 ] == len (set (data_handler .train_labels (0 )))
27+ # Perform dump and load
28+ import tempfile
29+ with tempfile .TemporaryDirectory () as tmpdir :
30+ import pathlib
31+ dump_path = pathlib .Path (tmpdir )
32+ Dumper .dump (scorer , dump_path )
33+
34+ # Create new scorer instance and load state
35+ new_scorer = BERTLoRAScorer (transformer_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 8 )
36+ Dumper .load (new_scorer , dump_path )
37+
38+ loaded_predictions = new_scorer .predict (test_data )
39+ np .testing .assert_array_almost_equal (initial_predictions , loaded_predictions , decimal = 5 )
2740
28- assert 0.0 <= np .min (predictions ) <= np .max (predictions ) <= 1.0
41+ assert initial_predictions .shape [0 ] == len (test_data )
42+ assert initial_predictions .shape [1 ] == len (set (data_handler .train_labels (0 )))
43+
44+ assert 0.0 <= np .min (initial_predictions ) <= np .max (initial_predictions ) <= 1.0
2945
3046 if not scorer ._multilabel :
31- for pred_row in predictions :
47+ for pred_row in initial_predictions :
3248 np .testing .assert_almost_equal (np .sum (pred_row ), 1.0 , decimal = 5 )
3349
3450 if hasattr (scorer , "predict_with_metadata" ):
0 commit comments