@@ -1,16 +1,65 @@
+import shutil
+import tempfile
+from pathlib import Path
+
 import numpy as np
 import pytest
 
-from autointent._dump_tools import Dumper
 from autointent.context.data_handler import DataHandler
 from autointent.modules import BERTLoRAScorer
 
 
+def test_lora_scorer_dump_load(dataset):
+    """Test that BERTLoRAScorer can be saved and loaded while preserving predictions."""
+    data_handler = DataHandler(dataset)
+
+    # Create and train scorer
+    scorer_original = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
+
+    # Test data
+    test_data = [
+        "why is there a hold on my account",
+        "why is my bank account frozen",
+    ]
+
+    # Get predictions before saving
+    predictions_before = scorer_original.predict(test_data)
+
+    # Create temp directory and save model
+    temp_dir_path = Path(tempfile.mkdtemp(prefix="lora_scorer_test_"))
+    try:
+        # Save the model
+        scorer_original.dump(str(temp_dir_path))
+
+        # Create a new scorer and load saved model
+        scorer_loaded = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+        scorer_loaded.load(str(temp_dir_path))
+
+        # Verify model and tokenizer are loaded
+        assert hasattr(scorer_loaded, "_model")
+        assert scorer_loaded._model is not None
+        assert hasattr(scorer_loaded, "_tokenizer")
+        assert scorer_loaded._tokenizer is not None
+
+        # Get predictions after loading
+        predictions_after = scorer_loaded.predict(test_data)
+
+        # Verify predictions match
+        assert predictions_before.shape == predictions_after.shape
+        np.testing.assert_allclose(predictions_before, predictions_after, atol=1e-6)
+
+    finally:
+        # Clean up
+        shutil.rmtree(temp_dir_path, ignore_errors=True)  # workaround for Windows permission error
+
+
 def test_lora_prediction(dataset):
-    """Test that the transformer model can fit and make predictions."""
+    """Test that the LoRA model can fit and make predictions."""
     data_handler = DataHandler(dataset)
 
     scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
     test_data = [
@@ -21,49 +70,47 @@ def test_lora_prediction(dataset):
         "can you tell me why is my bank account frozen",
     ]
 
-    initial_predictions = scorer.predict(test_data)
+    predictions = scorer.predict(test_data)
 
-    import tempfile
-    with tempfile.TemporaryDirectory() as tmpdir:
-        import pathlib
-        dump_path = pathlib.Path(tmpdir)
-        Dumper.dump(scorer, dump_path)
+    # Verify prediction shape
+    assert predictions.shape[0] == len(test_data)
+    assert predictions.shape[1] == len(set(data_handler.train_labels(0)))
 
-        new_scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
-        Dumper.load(new_scorer, dump_path)
-
-        loaded_predictions = new_scorer.predict(test_data)
-        np.testing.assert_array_almost_equal(initial_predictions, loaded_predictions, decimal=5)
-
-        assert initial_predictions.shape[0] == len(test_data)
-        assert initial_predictions.shape[1] == len(set(data_handler.train_labels(0)))
-
-        assert 0.0 <= np.min(initial_predictions) <= np.max(initial_predictions) <= 1.0
+    # Verify predictions are probabilities
+    assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0
 
+    # Verify probabilities sum to 1 for multiclass
     if not scorer._multilabel:
-        for pred_row in initial_predictions:
+        for pred_row in predictions:
             np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)
 
+    # Test metadata function if available
     if hasattr(scorer, "predict_with_metadata"):
         predictions, metadata = scorer.predict_with_metadata(test_data)
         assert len(predictions) == len(test_data)
         assert metadata is None
 
 
-def test_bert_cache_clearing(dataset):
-    """Test that the transformer model properly handles cache clearing."""
+def test_lora_cache_clearing(dataset):
+    """Test that the LoRA model properly handles cache clearing."""
     data_handler = DataHandler(dataset)
 
     scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
     test_data = ["test text"]
 
+    # Should work before clearing cache
     scorer.predict(test_data)
+
+    # Clear the cache
    scorer.clear_cache()
 
+    # Verify model and tokenizer are removed
     assert not hasattr(scorer, "_model") or scorer._model is None
     assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None
 
+    # Should raise exception after clearing cache
     with pytest.raises(RuntimeError):
-        scorer.predict(test_data)
+        scorer.predict(test_data)