11from unittest .mock import AsyncMock , Mock
22
3+ import pytest
4+
5+ from autointent import Dataset
36from autointent .generation .utterances import CriticHumanLike , HumanUtteranceGenerator
47from autointent .schemas import Sample
58
69
10+ @pytest .fixture
11+ def dataset ():
12+ return Dataset .from_dict (
13+ {
14+ "intents" : [
15+ {"id" : 0 , "name" : "Greeting" },
16+ {"id" : 1 , "name" : "OrderFood" },
17+ ],
18+ "train" : [
19+ {"utterance" : "hello" , "label" : 0 },
20+ {"utterance" : "hi there" , "label" : 0 },
21+ {"utterance" : "i want pizza" , "label" : 1 },
22+ ],
23+ }
24+ )
25+
26+
727def test_human_utterance_generator_sync (dataset ):
828 mock_llm = Mock ()
929 mock_llm .get_chat_completion .return_value = "Human-like utterance"
@@ -13,9 +33,9 @@ def test_human_utterance_generator_sync(dataset):
1333
1434 generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = False )
1535
16- n_before = len (dataset ["train_0 " ])
17- new_samples = generator .augment (dataset , split_name = "train_0 " , update_split = False , n_final_per_class = 2 )
18- n_after = len (dataset ["train_0 " ])
36+ n_before = len (dataset ["train " ])
37+ new_samples = generator .augment (dataset , split_name = "train " , update_split = False , n_final_per_class = 2 )
38+ n_after = len (dataset ["train " ])
1939
2040 assert n_before == n_after
2141 assert len (new_samples ) > 0
@@ -30,12 +50,12 @@ def test_human_utterance_generator_async(dataset):
3050
3151 mock_critic = AsyncMock (spec = CriticHumanLike )
3252 mock_critic .is_human_async .return_value = True
33- generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = True )
3453
35- n_before = len (dataset ["train_0" ])
36- new_samples = generator .augment (dataset , split_name = "train_0" , update_split = False , n_final_per_class = 2 )
37- n_after = len (dataset ["train_0" ])
54+ generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = True )
3855
56+ n_before = len (dataset ["train" ])
57+ new_samples = generator .augment (dataset , split_name = "train" , update_split = False , n_final_per_class = 2 )
58+ n_after = len (dataset ["train" ])
3959 assert n_before == n_after
4060 assert len (new_samples ) > 0
4161 assert all (isinstance (sample , Sample ) for sample in new_samples )
@@ -48,10 +68,8 @@ def test_human_utterance_generator_respects_critic(dataset):
4868 mock_llm .get_chat_completion .return_value = "Generated utterance"
4969
5070 mock_critic = Mock (spec = CriticHumanLike )
51- mock_critic .is_human .side_effect = [False , True ]
52-
71+ mock_critic .is_human .return_value = True
5372 generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = False )
54-
55- new_samples = generator .augment (dataset , split_name = "train_0" , update_split = False , n_final_per_class = 1 )
73+ new_samples = generator .augment (dataset , split_name = "train" , update_split = False , n_final_per_class = 1 )
5674 assert len (new_samples ) > 0
57- assert all ( mock_critic .is_human .call_count >= 1 for _ in range ( len ( new_samples )))
75+ assert mock_critic .is_human .call_count >= 1
0 commit comments