1+ import numpy as np
2+ import pytest
3+
4+ from autointent .context .data_handler import DataHandler
5+ from autointent .modules import BERTLoRAScorer
6+
7+
8+ def test_lora_prediction (dataset ):
9+ """Test that the transformer model can fit and make predictions."""
10+ data_handler = DataHandler (dataset )
11+
12+ scorer = BERTLoRAScorer (model_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 8 )
13+ scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
14+
15+ test_data = [
16+ "why is there a hold on my american saving bank account" ,
17+ "i am not sure why my account is blocked" ,
18+ "why is there a hold on my capital one checking account" ,
19+ "i think my account is blocked but i do not know the reason" ,
20+ "can you tell me why is my bank account frozen" ,
21+ ]
22+
23+ predictions = scorer .predict (test_data )
24+
25+ assert predictions .shape [0 ] == len (test_data )
26+ assert predictions .shape [1 ] == len (set (data_handler .train_labels (0 )))
27+
28+ assert 0.0 <= np .min (predictions ) <= np .max (predictions ) <= 1.0
29+
30+ if not scorer ._multilabel :
31+ for pred_row in predictions :
32+ np .testing .assert_almost_equal (np .sum (pred_row ), 1.0 , decimal = 5 )
33+
34+ if hasattr (scorer , "predict_with_metadata" ):
35+ predictions , metadata = scorer .predict_with_metadata (test_data )
36+ assert len (predictions ) == len (test_data )
37+ assert metadata is None
38+
39+
40+ def test_bert_cache_clearing (dataset ):
41+ """Test that the transformer model properly handles cache clearing."""
42+ data_handler = DataHandler (dataset )
43+
44+ scorer = BERTLoRAScorer (model_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 8 )
45+ scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
46+
47+ test_data = ["test text" ]
48+
49+ scorer .predict (test_data )
50+ scorer .clear_cache ()
51+
52+ assert not hasattr (scorer , "_model" ) or scorer ._model is None
53+ assert not hasattr (scorer , "_tokenizer" ) or scorer ._tokenizer is None
54+
55+ with pytest .raises (RuntimeError ):
56+ scorer .predict (test_data )
0 commit comments