1- from unittest .mock import MagicMock
2-
31from autointent .modules .embedding import LogRegEmbedding
42
53
@@ -21,17 +19,15 @@ def test_fit_trains_model():
2119 assert module ._label_encoder .classes_ .tolist () == [0 , 1 ]
2220
2321
24- def test_score_evaluates_model ():
22+ def test_predict_evaluates_model ():
2523 module = LogRegEmbedding (k = 5 , embedder_name = "sergeyzh/rubert-tiny-turbo" )
2624
2725 utterances = ["hello" , "goodbye" , "hi" , "bye" , "bye" , "hello" , "welcome" , "hi123" , "hiii" , "bye-bye" , "bye!" ]
2826 labels = [0 , 1 , 0 , 1 , 1 , 0 , 0 , 0 , 0 , 1 , 1 ]
2927 module .fit (utterances , labels )
3028
31- mock_context = MagicMock ()
32- mock_context .data_handler .test_utterances .return_value = ["hello" , "goodbye" ]
33- mock_context .data_handler .test_labels .return_value = [[1 , 0 ], [0 , 1 ]]
34-
35- scores = module .score (mock_context , split = "test" )
29+ probas = module .predict (["hello" , "bye" ])
3630
37- assert isinstance (scores , dict )
31+ assert len (probas ) == 2
32+ assert probas [0 ][0 ] > probas [0 ][1 ]
33+ assert probas [1 ][1 ] > probas [1 ][0 ]
0 commit comments