Commit 8c51134

split test_discuss for async

1 parent: d71edf4

File tree

1 file changed: +331 -0 lines

tests/test_discuss_async.py

Lines changed: 331 additions & 0 deletions
@@ -0,0 +1,331 @@
# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock

# import asynctest
# from asynctest import mock as async_mock

import google.ai.generativelanguage as glm

from google.generativeai import discuss
from google.generativeai import client
import google.generativeai as genai
from absl.testing import absltest
from absl.testing import parameterized

# TODO: replace returns with 'assert' statements


class UnitTests(parameterized.TestCase):
    def setUp(self):
        self.client = unittest.mock.MagicMock()

        client.default_discuss_client = self.client

        self.observed_request = None

        def fake_generate_message(
            request: glm.GenerateMessageRequest,
        ) -> glm.GenerateMessageResponse:
            self.observed_request = request
            return glm.GenerateMessageResponse(
                messages=request.prompt.messages,
                candidates=[
                    glm.Message(content="a", author="1"),
                    glm.Message(content="b", author="1"),
                    glm.Message(content="c", author="1"),
                ],
            )

        self.client.generate_message = fake_generate_message

    @parameterized.named_parameters(
        ["string", "Hello", ""],
        ["dict", {"content": "Hello"}, ""],
        ["dict_author", {"content": "Hello", "author": "me"}, "me"],
        ["proto", glm.Message(content="Hello"), ""],
        ["proto_author", glm.Message(content="Hello", author="me"), "me"],
    )
    def test_make_message(self, message, author):
        x = discuss._make_message(message)
        self.assertIsInstance(x, glm.Message)
        self.assertEqual("Hello", x.content)
        self.assertEqual(author, x.author)

    @parameterized.named_parameters(
        ["string", "Hello", ["Hello"]],
        ["dict", {"content": "Hello"}, ["Hello"]],
        ["proto", glm.Message(content="Hello"), ["Hello"]],
        [
            "list",
            ["hello0", {"content": "hello1"}, glm.Message(content="hello2")],
            ["hello0", "hello1", "hello2"],
        ],
    )
    def test_make_messages(self, messages, expected_contents):
        messages = discuss._make_messages(messages)
        for expected, message in zip(expected_contents, messages):
            self.assertEqual(expected, message.content)

    @parameterized.named_parameters(
        ["tuple", ("hello", {"content": "goodbye"})],
        ["iterable", iter(["hello", "goodbye"])],
        ["dict", {"input": "hello", "output": "goodbye"}],
        [
            "proto",
            glm.Example(
                input=glm.Message(content="hello"),
                output=glm.Message(content="goodbye"),
            ),
        ],
    )
    def test_make_example(self, example):
        x = discuss._make_example(example)
        self.assertIsInstance(x, glm.Example)
        self.assertEqual("hello", x.input.content)
        self.assertEqual("goodbye", x.output.content)
        return

    @parameterized.named_parameters(
        [
            "messages",
            [
                "Hi",
                {"content": "Hello!"},
                "what's your name?",
                glm.Message(content="Dave, what's yours"),
            ],
        ],
        [
            "examples",
            [
                ("Hi", "Hello!"),
                {
                    "input": "what's your name?",
                    "output": {"content": "Dave, what's yours"},
                },
            ],
        ],
    )
    def test_make_examples(self, examples):
        examples = discuss._make_examples(examples)
        self.assertLen(examples, 2)
        self.assertEqual(examples[0].input.content, "Hi")
        self.assertEqual(examples[0].output.content, "Hello!")
        self.assertEqual(examples[1].input.content, "what's your name?")
        self.assertEqual(examples[1].output.content, "Dave, what's yours")

        return

    def test_make_examples_from_example(self):
        ex_dict = {"input": "hello", "output": "meow!"}
        example = discuss._make_example(ex_dict)
        examples1 = discuss._make_examples(ex_dict)
        examples2 = discuss._make_examples(discuss._make_example(ex_dict))

        self.assertEqual(example, examples1[0])
        self.assertEqual(example, examples2[0])

    @parameterized.named_parameters(
        ["str", "hello"],
        ["message", glm.Message(content="hello")],
        ["messages", ["hello"]],
        ["dict", {"messages": "hello"}],
        ["dict2", {"messages": ["hello"]}],
        ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])],
    )
    def test_make_message_prompt_from_messages(self, prompt):
        x = discuss._make_message_prompt(prompt)
        self.assertIsInstance(x, glm.MessagePrompt)
        self.assertEqual(x.messages[0].content, "hello")
        return

    @parameterized.named_parameters(
        [
            "dict",
            [
                {
                    "context": "you are a cat",
                    "examples": ["are you hungry?", "meow!"],
                    "messages": "hello",
                }
            ],
            {},
        ],
        [
            "kwargs",
            [],
            {
                "context": "you are a cat",
                "examples": ["are you hungry?", "meow!"],
                "messages": "hello",
            },
        ],
        [
            "proto",
            [
                glm.MessagePrompt(
                    context="you are a cat",
                    examples=[
                        glm.Example(
                            input=glm.Message(content="are you hungry?"),
                            output=glm.Message(content="meow!"),
                        )
                    ],
                    messages=[glm.Message(content="hello")],
                )
            ],
            {},
        ],
    )
    def test_make_message_prompt_from_prompt(self, args, kwargs):
        x = discuss._make_message_prompt(*args, **kwargs)
        self.assertIsInstance(x, glm.MessagePrompt)
        self.assertEqual(x.context, "you are a cat")
        self.assertEqual(x.examples[0].input.content, "are you hungry?")
        self.assertEqual(x.examples[0].output.content, "meow!")
        self.assertEqual(x.messages[0].content, "hello")

    def test_make_generate_message_request_nested(
        self,
    ):
        request0 = discuss._make_generate_message_request(
            **{
                "model": "Dave",
                "context": "you are a cat",
                "examples": ["hello", "meow", "are you hungry?", "meow!"],
                "messages": "Please catch that mouse.",
                "temperature": 0.2,
                "candidate_count": 7,
            }
        )
        request1 = discuss._make_generate_message_request(
            **{
                "model": "Dave",
                "prompt": {
                    "context": "you are a cat",
                    "examples": ["hello", "meow", "are you hungry?", "meow!"],
                    "messages": "Please catch that mouse.",
                },
                "temperature": 0.2,
                "candidate_count": 7,
            }
        )

        self.assertIsInstance(request0, glm.GenerateMessageRequest)
        self.assertIsInstance(request1, glm.GenerateMessageRequest)
        self.assertEqual(request0, request1)

    @parameterized.parameters(
        {"prompt": {}, "context": "You are a cat."},
        {"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]},
        {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"},
    )
    def test_make_generate_message_request_flat_prompt_conflict(
        self,
        context=None,
        examples=None,
        messages=None,
        prompt=None,
    ):
        with self.assertRaises(ValueError):
            x = discuss._make_generate_message_request(
                model="test",
                context=context,
                examples=examples,
                messages=messages,
                prompt=prompt,
            )

    @parameterized.parameters(
        {"kwargs": {"context": "You are a cat."}},
        {"kwargs": {"messages": "hello"}},
        {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}},
        {"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}},
    )
    def test_reply(self, kwargs):
        response = genai.chat(**kwargs)
        first_messages = response.messages

        self.assertEqual("a", response.last)
        self.assertEqual(
            [
                {"author": "1", "content": "a"},
                {"author": "1", "content": "b"},
                {"author": "1", "content": "c"},
            ],
            response.candidates,
        )

        response = response.reply("again")

'''
class AsyncTests(parameterized.TestCase, asynctest.TestCase):
    async def test_chat_async(self):
        client = async_mock.MagicMock()

        observed_request = None

        async def fake_generate_message(
            request: glm.GenerateMessageRequest,
        ) -> glm.GenerateMessageResponse:
            nonlocal observed_request
            observed_request = request
            return glm.GenerateMessageResponse(
                candidates=[
                    glm.Message(
                        author="1", content="Why did the chicken cross the road?"
                    )
                ]
            )

        client.generate_message = fake_generate_message

        observed_response = await discuss.chat_async(
            model="models/bard",
            context="Example Prompt",
            examples=[["Example from human", "Example response from AI"]],
            messages=["Tell me a joke"],
            temperature=0.75,
            candidate_count=1,
            client=client,
        )

        self.assertEqual(
            observed_request,
            glm.GenerateMessageRequest(
                model="models/bard",
                prompt=glm.MessagePrompt(
                    context="Example Prompt",
                    examples=[
                        glm.Example(
                            input=glm.Message(content="Example from human"),
                            output=glm.Message(content="Example response from AI"),
                        )
                    ],
                    messages=[glm.Message(author="0", content="Tell me a joke")],
                ),
                temperature=0.75,
                candidate_count=1,
            ),
        )
        self.assertEqual(
            observed_response.candidates,
            [{"author": "1", "content": "Why did the chicken cross the road?"}],
        )
'''

if __name__ == "__main__":
    absltest.main()
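Note on the quoted-out AsyncTests block above: the commit leaves the async test inside a triple-quoted string and comments out the asynctest imports, so it does not run yet. Below is a minimal sketch of one way it could be enabled without asynctest, assuming Python 3.8+ where the standard library provides unittest.IsolatedAsyncioTestCase. The class name AsyncTestsSketch is hypothetical, and the discuss.chat_async call shape, the client= keyword, and the expected candidates are copied from the quoted block, not verified independently here.

# Sketch only, not part of this commit.
# Assumption: unittest.IsolatedAsyncioTestCase (stdlib, Python 3.8+) replaces
# the commented-out asynctest.TestCase; everything else mirrors the quoted block.
import unittest.mock

import google.ai.generativelanguage as glm
from google.generativeai import discuss


class AsyncTestsSketch(unittest.IsolatedAsyncioTestCase):
    async def test_chat_async(self):
        observed_request = None

        async def fake_generate_message(
            request: glm.GenerateMessageRequest,
        ) -> glm.GenerateMessageResponse:
            # Record the request so the test can assert on what chat_async built.
            nonlocal observed_request
            observed_request = request
            return glm.GenerateMessageResponse(
                candidates=[
                    glm.Message(author="1", content="Why did the chicken cross the road?")
                ]
            )

        # A plain MagicMock stands in for async_mock.MagicMock from the quoted
        # block; the async fake is assigned directly, so awaiting its result works.
        client = unittest.mock.MagicMock()
        client.generate_message = fake_generate_message

        observed_response = await discuss.chat_async(
            model="models/bard",
            context="Example Prompt",
            examples=[["Example from human", "Example response from AI"]],
            messages=["Tell me a joke"],
            temperature=0.75,
            candidate_count=1,
            client=client,
        )

        self.assertEqual(observed_request.model, "models/bard")
        self.assertEqual(
            observed_response.candidates,
            [{"author": "1", "content": "Why did the chicken cross the road?"}],
        )


if __name__ == "__main__":
    unittest.main()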
