|
14 | 14 | # limitations under the License.
|
15 | 15 | import unittest.mock
|
16 | 16 |
|
17 |
| -#import asynctest |
18 |
| -#from asynctest import mock as async_mock |
| 17 | +import asynctest |
| 18 | +from asynctest import mock as async_mock |
19 | 19 |
|
20 | 20 | import google.ai.generativelanguage as glm
|
21 | 21 |
|
|
28 | 28 | # TODO: replace returns with 'assert' statements
|
29 | 29 |
|
30 | 30 |
|
31 |
| -class UnitTests(parameterized.TestCase): |
32 |
| - def setUp(self): |
33 |
| - self.client = unittest.mock.MagicMock() |
34 |
| - |
35 |
| - client.default_discuss_client = self.client |
36 |
| - |
37 |
| - self.observed_request = None |
38 |
| - |
39 |
| - def fake_generate_message( |
40 |
| - request: glm.GenerateMessageRequest, |
41 |
| - ) -> glm.GenerateMessageResponse: |
42 |
| - self.observed_request = request |
43 |
| - return glm.GenerateMessageResponse( |
44 |
| - messages=request.prompt.messages, |
45 |
| - candidates=[ |
46 |
| - glm.Message(content="a", author="1"), |
47 |
| - glm.Message(content="b", author="1"), |
48 |
| - glm.Message(content="c", author="1"), |
49 |
| - ], |
50 |
| - ) |
51 |
| - |
52 |
| - self.client.generate_message = fake_generate_message |
53 |
| - |
54 |
| - @parameterized.named_parameters( |
55 |
| - ["string", "Hello", ""], |
56 |
| - ["dict", {"content": "Hello"}, ""], |
57 |
| - ["dict_author", {"content": "Hello", "author": "me"}, "me"], |
58 |
| - ["proto", glm.Message(content="Hello"), ""], |
59 |
| - ["proto_author", glm.Message(content="Hello", author="me"), "me"], |
60 |
| - ) |
61 |
| - def test_make_message(self, message, author): |
62 |
| - x = discuss._make_message(message) |
63 |
| - self.assertIsInstance(x, glm.Message) |
64 |
| - self.assertEqual("Hello", x.content) |
65 |
| - self.assertEqual(author, x.author) |
66 |
| - |
67 |
| - @parameterized.named_parameters( |
68 |
| - ["string", "Hello", ["Hello"]], |
69 |
| - ["dict", {"content": "Hello"}, ["Hello"]], |
70 |
| - ["proto", glm.Message(content="Hello"), ["Hello"]], |
71 |
| - [ |
72 |
| - "list", |
73 |
| - ["hello0", {"content": "hello1"}, glm.Message(content="hello2")], |
74 |
| - ["hello0", "hello1", "hello2"], |
75 |
| - ], |
76 |
| - ) |
77 |
| - def test_make_messages(self, messages, expected_contents): |
78 |
| - messages = discuss._make_messages(messages) |
79 |
| - for expected, message in zip(expected_contents, messages): |
80 |
| - self.assertEqual(expected, message.content) |
81 |
| - |
82 |
| - @parameterized.named_parameters( |
83 |
| - ["tuple", ("hello", {"content": "goodbye"})], |
84 |
| - ["iterable", iter(["hello", "goodbye"])], |
85 |
| - ["dict", {"input": "hello", "output": "goodbye"}], |
86 |
| - [ |
87 |
| - "proto", |
88 |
| - glm.Example( |
89 |
| - input=glm.Message(content="hello"), |
90 |
| - output=glm.Message(content="goodbye"), |
91 |
| - ), |
92 |
| - ], |
93 |
| - ) |
94 |
| - def test_make_example(self, example): |
95 |
| - x = discuss._make_example(example) |
96 |
| - self.assertIsInstance(x, glm.Example) |
97 |
| - self.assertEqual("hello", x.input.content) |
98 |
| - self.assertEqual("goodbye", x.output.content) |
99 |
| - return |
100 |
| - |
101 |
| - @parameterized.named_parameters( |
102 |
| - [ |
103 |
| - "messages", |
104 |
| - [ |
105 |
| - "Hi", |
106 |
| - {"content": "Hello!"}, |
107 |
| - "what's your name?", |
108 |
| - glm.Message(content="Dave, what's yours"), |
109 |
| - ], |
110 |
| - ], |
111 |
| - [ |
112 |
| - "examples", |
113 |
| - [ |
114 |
| - ("Hi", "Hello!"), |
115 |
| - { |
116 |
| - "input": "what's your name?", |
117 |
| - "output": {"content": "Dave, what's yours"}, |
118 |
| - }, |
119 |
| - ], |
120 |
| - ], |
121 |
| - ) |
122 |
| - def test_make_examples(self, examples): |
123 |
| - examples = discuss._make_examples(examples) |
124 |
| - self.assertLen(examples, 2) |
125 |
| - self.assertEqual(examples[0].input.content, "Hi") |
126 |
| - self.assertEqual(examples[0].output.content, "Hello!") |
127 |
| - self.assertEqual(examples[1].input.content, "what's your name?") |
128 |
| - self.assertEqual(examples[1].output.content, "Dave, what's yours") |
129 |
| - |
130 |
| - return |
131 |
| - |
132 |
| - def test_make_examples_from_example(self): |
133 |
| - ex_dict = {"input": "hello", "output": "meow!"} |
134 |
| - example = discuss._make_example(ex_dict) |
135 |
| - examples1 = discuss._make_examples(ex_dict) |
136 |
| - examples2 = discuss._make_examples(discuss._make_example(ex_dict)) |
137 |
| - |
138 |
| - self.assertEqual(example, examples1[0]) |
139 |
| - self.assertEqual(example, examples2[0]) |
140 |
| - |
141 |
| - @parameterized.named_parameters( |
142 |
| - ["str", "hello"], |
143 |
| - ["message", glm.Message(content="hello")], |
144 |
| - ["messages", ["hello"]], |
145 |
| - ["dict", {"messages": "hello"}], |
146 |
| - ["dict2", {"messages": ["hello"]}], |
147 |
| - ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])], |
148 |
| - ) |
149 |
| - def test_make_message_prompt_from_messages(self, prompt): |
150 |
| - x = discuss._make_message_prompt(prompt) |
151 |
| - self.assertIsInstance(x, glm.MessagePrompt) |
152 |
| - self.assertEqual(x.messages[0].content, "hello") |
153 |
| - return |
154 |
| - |
155 |
| - @parameterized.named_parameters( |
156 |
| - [ |
157 |
| - "dict", |
158 |
| - [ |
159 |
| - { |
160 |
| - "context": "you are a cat", |
161 |
| - "examples": ["are you hungry?", "meow!"], |
162 |
| - "messages": "hello", |
163 |
| - } |
164 |
| - ], |
165 |
| - {}, |
166 |
| - ], |
167 |
| - [ |
168 |
| - "kwargs", |
169 |
| - [], |
170 |
| - { |
171 |
| - "context": "you are a cat", |
172 |
| - "examples": ["are you hungry?", "meow!"], |
173 |
| - "messages": "hello", |
174 |
| - }, |
175 |
| - ], |
176 |
| - [ |
177 |
| - "proto", |
178 |
| - [ |
179 |
| - glm.MessagePrompt( |
180 |
| - context="you are a cat", |
181 |
| - examples=[ |
182 |
| - glm.Example( |
183 |
| - input=glm.Message(content="are you hungry?"), |
184 |
| - output=glm.Message(content="meow!"), |
185 |
| - ) |
186 |
| - ], |
187 |
| - messages=[glm.Message(content="hello")], |
188 |
| - ) |
189 |
| - ], |
190 |
| - {}, |
191 |
| - ], |
192 |
| - ) |
193 |
| - def test_make_message_prompt_from_prompt(self, args, kwargs): |
194 |
| - x = discuss._make_message_prompt(*args, **kwargs) |
195 |
| - self.assertIsInstance(x, glm.MessagePrompt) |
196 |
| - self.assertEqual(x.context, "you are a cat") |
197 |
| - self.assertEqual(x.examples[0].input.content, "are you hungry?") |
198 |
| - self.assertEqual(x.examples[0].output.content, "meow!") |
199 |
| - self.assertEqual(x.messages[0].content, "hello") |
200 |
| - |
201 |
| - def test_make_generate_message_request_nested( |
202 |
| - self, |
203 |
| - ): |
204 |
| - request0 = discuss._make_generate_message_request( |
205 |
| - **{ |
206 |
| - "model": "Dave", |
207 |
| - "context": "you are a cat", |
208 |
| - "examples": ["hello", "meow", "are you hungry?", "meow!"], |
209 |
| - "messages": "Please catch that mouse.", |
210 |
| - "temperature": 0.2, |
211 |
| - "candidate_count": 7, |
212 |
| - } |
213 |
| - ) |
214 |
| - request1 = discuss._make_generate_message_request( |
215 |
| - **{ |
216 |
| - "model": "Dave", |
217 |
| - "prompt": { |
218 |
| - "context": "you are a cat", |
219 |
| - "examples": ["hello", "meow", "are you hungry?", "meow!"], |
220 |
| - "messages": "Please catch that mouse.", |
221 |
| - }, |
222 |
| - "temperature": 0.2, |
223 |
| - "candidate_count": 7, |
224 |
| - } |
225 |
| - ) |
226 |
| - |
227 |
| - self.assertIsInstance(request0, glm.GenerateMessageRequest) |
228 |
| - self.assertIsInstance(request1, glm.GenerateMessageRequest) |
229 |
| - self.assertEqual(request0, request1) |
230 |
| - |
231 |
| - @parameterized.parameters( |
232 |
| - {"prompt": {}, "context": "You are a cat."}, |
233 |
| - {"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]}, |
234 |
| - {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"}, |
235 |
| - ) |
236 |
| - def test_make_generate_message_request_flat_prompt_conflict( |
237 |
| - self, |
238 |
| - context=None, |
239 |
| - examples=None, |
240 |
| - messages=None, |
241 |
| - prompt=None, |
242 |
| - ): |
243 |
| - with self.assertRaises(ValueError): |
244 |
| - x = discuss._make_generate_message_request( |
245 |
| - model="test", |
246 |
| - context=context, |
247 |
| - examples=examples, |
248 |
| - messages=messages, |
249 |
| - prompt=prompt, |
250 |
| - ) |
251 |
| - |
252 |
| - @parameterized.parameters( |
253 |
| - {"kwargs": {"context": "You are a cat."}}, |
254 |
| - {"kwargs": {"messages": "hello"}}, |
255 |
| - {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}}, |
256 |
| - {"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}}, |
257 |
| - ) |
258 |
| - def test_reply(self, kwargs): |
259 |
| - response = genai.chat(**kwargs) |
260 |
| - first_messages = response.messages |
261 |
| - |
262 |
| - self.assertEqual("a", response.last) |
263 |
| - self.assertEqual( |
264 |
| - [ |
265 |
| - {"author": "1", "content": "a"}, |
266 |
| - {"author": "1", "content": "b"}, |
267 |
| - {"author": "1", "content": "c"}, |
268 |
| - ], |
269 |
| - response.candidates, |
270 |
| - ) |
271 |
| - |
272 |
| - response = response.reply("again") |
273 |
| - |
274 |
| -''' |
275 | 31 | class AsyncTests(parameterized.TestCase, asynctest.TestCase):
|
276 | 32 | async def test_chat_async(self):
|
277 | 33 | client = async_mock.MagicMock()
|
@@ -325,7 +81,7 @@ async def fake_generate_message(
|
325 | 81 | observed_response.candidates,
|
326 | 82 | [{"author": "1", "content": "Why did the chicken cross the road?"}],
|
327 | 83 | )
|
328 |
| -''' |
| 84 | + |
329 | 85 |
|
330 | 86 | if __name__ == "__main__":
|
331 | 87 | absltest.main()
|
0 commit comments