1+ from unittest .mock import Mock , AsyncMock
2+ import pytest
3+ from autointent .generation import Generator
4+
5+ from autointent .generation .utterances import HumanUtteranceGenerator , CriticHumanLike
6+ from autointent import Dataset , Sample
7+
8+
9+ def test_human_utterance_generator_sync (dataset ):
10+ mock_llm = Mock ()
11+ mock_llm .get_chat_completion .return_value = "Human-like utterance"
12+
13+ mock_critic = Mock (spec = CriticHumanLike )
14+ mock_critic .is_human .return_value = True
15+
16+ generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = False )
17+
18+ n_before = len (dataset ["train_0" ])
19+ new_samples = generator .augment (dataset , split_name = "train_0" , update_split = False , n_final_per_class = 2 )
20+ n_after = len (dataset ["train_0" ])
21+
22+ assert n_before == n_after
23+ assert len (new_samples ) > 0
24+ assert all (isinstance (sample , Sample ) for sample in new_samples )
25+ assert all ("utterance" in sample .dict () for sample in new_samples )
26+ assert all ("label" in sample .dict () for sample in new_samples )
27+
28+
29+ def test_human_utterance_generator_async (dataset ):
30+ mock_llm = AsyncMock ()
31+ mock_llm .get_chat_completion_async .return_value = "Human-like utterance"
32+
33+ mock_critic = AsyncMock (spec = CriticHumanLike )
34+ mock_critic .is_human_async .return_value = True
35+ generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = True )
36+
37+ n_before = len (dataset ["train_0" ])
38+ new_samples = generator .augment (dataset , split_name = "train_0" , update_split = False , n_final_per_class = 2 )
39+ n_after = len (dataset ["train_0" ])
40+
41+ assert n_before == n_after
42+ assert len (new_samples ) > 0
43+ assert all (isinstance (sample , Sample ) for sample in new_samples )
44+ assert all ("utterance" in sample .dict () for sample in new_samples )
45+ assert all ("label" in sample .dict () for sample in new_samples )
46+
47+
48+ def test_human_utterance_generator_respects_critic (dataset ):
49+ mock_llm = Mock ()
50+ mock_llm .get_chat_completion .return_value = "Generated utterance"
51+
52+ mock_critic = Mock (spec = CriticHumanLike )
53+ mock_critic .is_human .side_effect = [False , True ]
54+
55+ generator = HumanUtteranceGenerator (mock_llm , mock_critic , async_mode = False )
56+
57+ new_samples = generator .augment (dataset , split_name = "train_0" , update_split = False , n_final_per_class = 1 )
58+ assert len (new_samples ) > 0
59+ assert all (mock_critic .is_human .call_count >= 1 for _ in range (len (new_samples )))
0 commit comments