Skip to content

Commit 30a1ebf

Browse files
committed
Split async tests for discuss.
1 parent 8c51134 commit 30a1ebf

File tree

2 files changed

+3
-306
lines changed

2 files changed

+3
-306
lines changed

tests/test_discuss.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
# limitations under the License.
1515
import unittest.mock
1616

17-
#import asynctest
18-
#from asynctest import mock as async_mock
19-
2017
import google.ai.generativelanguage as glm
2118

2219
from google.generativeai import discuss
@@ -271,61 +268,5 @@ def test_reply(self, kwargs):
271268

272269
response = response.reply("again")
273270

274-
'''
275-
class AsyncTests(parameterized.TestCase, asynctest.TestCase):
276-
async def test_chat_async(self):
277-
client = async_mock.MagicMock()
278-
279-
observed_request = None
280-
281-
async def fake_generate_message(
282-
request: glm.GenerateMessageRequest,
283-
) -> glm.GenerateMessageResponse:
284-
nonlocal observed_request
285-
observed_request = request
286-
return glm.GenerateMessageResponse(
287-
candidates=[
288-
glm.Message(
289-
author="1", content="Why did the chicken cross the road?"
290-
)
291-
]
292-
)
293-
294-
client.generate_message = fake_generate_message
295-
296-
observed_response = await discuss.chat_async(
297-
model="models/bard",
298-
context="Example Prompt",
299-
examples=[["Example from human", "Example response from AI"]],
300-
messages=["Tell me a joke"],
301-
temperature=0.75,
302-
candidate_count=1,
303-
client=client,
304-
)
305-
306-
self.assertEqual(
307-
observed_request,
308-
glm.GenerateMessageRequest(
309-
model="models/bard",
310-
prompt=glm.MessagePrompt(
311-
context="Example Prompt",
312-
examples=[
313-
glm.Example(
314-
input=glm.Message(content="Example from human"),
315-
output=glm.Message(content="Example response from AI"),
316-
)
317-
],
318-
messages=[glm.Message(author="0", content="Tell me a joke")],
319-
),
320-
temperature=0.75,
321-
candidate_count=1,
322-
),
323-
)
324-
self.assertEqual(
325-
observed_response.candidates,
326-
[{"author": "1", "content": "Why did the chicken cross the road?"}],
327-
)
328-
'''
329-
330271
if __name__ == "__main__":
331272
absltest.main()

tests/test_discuss_async.py

Lines changed: 3 additions & 247 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# limitations under the License.
1515
import unittest.mock
1616

17-
#import asynctest
18-
#from asynctest import mock as async_mock
17+
import asynctest
18+
from asynctest import mock as async_mock
1919

2020
import google.ai.generativelanguage as glm
2121

@@ -28,250 +28,6 @@
2828
# TODO: replace returns with 'assert' statements
2929

3030

31-
class UnitTests(parameterized.TestCase):
32-
def setUp(self):
33-
self.client = unittest.mock.MagicMock()
34-
35-
client.default_discuss_client = self.client
36-
37-
self.observed_request = None
38-
39-
def fake_generate_message(
40-
request: glm.GenerateMessageRequest,
41-
) -> glm.GenerateMessageResponse:
42-
self.observed_request = request
43-
return glm.GenerateMessageResponse(
44-
messages=request.prompt.messages,
45-
candidates=[
46-
glm.Message(content="a", author="1"),
47-
glm.Message(content="b", author="1"),
48-
glm.Message(content="c", author="1"),
49-
],
50-
)
51-
52-
self.client.generate_message = fake_generate_message
53-
54-
@parameterized.named_parameters(
55-
["string", "Hello", ""],
56-
["dict", {"content": "Hello"}, ""],
57-
["dict_author", {"content": "Hello", "author": "me"}, "me"],
58-
["proto", glm.Message(content="Hello"), ""],
59-
["proto_author", glm.Message(content="Hello", author="me"), "me"],
60-
)
61-
def test_make_message(self, message, author):
62-
x = discuss._make_message(message)
63-
self.assertIsInstance(x, glm.Message)
64-
self.assertEqual("Hello", x.content)
65-
self.assertEqual(author, x.author)
66-
67-
@parameterized.named_parameters(
68-
["string", "Hello", ["Hello"]],
69-
["dict", {"content": "Hello"}, ["Hello"]],
70-
["proto", glm.Message(content="Hello"), ["Hello"]],
71-
[
72-
"list",
73-
["hello0", {"content": "hello1"}, glm.Message(content="hello2")],
74-
["hello0", "hello1", "hello2"],
75-
],
76-
)
77-
def test_make_messages(self, messages, expected_contents):
78-
messages = discuss._make_messages(messages)
79-
for expected, message in zip(expected_contents, messages):
80-
self.assertEqual(expected, message.content)
81-
82-
@parameterized.named_parameters(
83-
["tuple", ("hello", {"content": "goodbye"})],
84-
["iterable", iter(["hello", "goodbye"])],
85-
["dict", {"input": "hello", "output": "goodbye"}],
86-
[
87-
"proto",
88-
glm.Example(
89-
input=glm.Message(content="hello"),
90-
output=glm.Message(content="goodbye"),
91-
),
92-
],
93-
)
94-
def test_make_example(self, example):
95-
x = discuss._make_example(example)
96-
self.assertIsInstance(x, glm.Example)
97-
self.assertEqual("hello", x.input.content)
98-
self.assertEqual("goodbye", x.output.content)
99-
return
100-
101-
@parameterized.named_parameters(
102-
[
103-
"messages",
104-
[
105-
"Hi",
106-
{"content": "Hello!"},
107-
"what's your name?",
108-
glm.Message(content="Dave, what's yours"),
109-
],
110-
],
111-
[
112-
"examples",
113-
[
114-
("Hi", "Hello!"),
115-
{
116-
"input": "what's your name?",
117-
"output": {"content": "Dave, what's yours"},
118-
},
119-
],
120-
],
121-
)
122-
def test_make_examples(self, examples):
123-
examples = discuss._make_examples(examples)
124-
self.assertLen(examples, 2)
125-
self.assertEqual(examples[0].input.content, "Hi")
126-
self.assertEqual(examples[0].output.content, "Hello!")
127-
self.assertEqual(examples[1].input.content, "what's your name?")
128-
self.assertEqual(examples[1].output.content, "Dave, what's yours")
129-
130-
return
131-
132-
def test_make_examples_from_example(self):
133-
ex_dict = {"input": "hello", "output": "meow!"}
134-
example = discuss._make_example(ex_dict)
135-
examples1 = discuss._make_examples(ex_dict)
136-
examples2 = discuss._make_examples(discuss._make_example(ex_dict))
137-
138-
self.assertEqual(example, examples1[0])
139-
self.assertEqual(example, examples2[0])
140-
141-
@parameterized.named_parameters(
142-
["str", "hello"],
143-
["message", glm.Message(content="hello")],
144-
["messages", ["hello"]],
145-
["dict", {"messages": "hello"}],
146-
["dict2", {"messages": ["hello"]}],
147-
["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])],
148-
)
149-
def test_make_message_prompt_from_messages(self, prompt):
150-
x = discuss._make_message_prompt(prompt)
151-
self.assertIsInstance(x, glm.MessagePrompt)
152-
self.assertEqual(x.messages[0].content, "hello")
153-
return
154-
155-
@parameterized.named_parameters(
156-
[
157-
"dict",
158-
[
159-
{
160-
"context": "you are a cat",
161-
"examples": ["are you hungry?", "meow!"],
162-
"messages": "hello",
163-
}
164-
],
165-
{},
166-
],
167-
[
168-
"kwargs",
169-
[],
170-
{
171-
"context": "you are a cat",
172-
"examples": ["are you hungry?", "meow!"],
173-
"messages": "hello",
174-
},
175-
],
176-
[
177-
"proto",
178-
[
179-
glm.MessagePrompt(
180-
context="you are a cat",
181-
examples=[
182-
glm.Example(
183-
input=glm.Message(content="are you hungry?"),
184-
output=glm.Message(content="meow!"),
185-
)
186-
],
187-
messages=[glm.Message(content="hello")],
188-
)
189-
],
190-
{},
191-
],
192-
)
193-
def test_make_message_prompt_from_prompt(self, args, kwargs):
194-
x = discuss._make_message_prompt(*args, **kwargs)
195-
self.assertIsInstance(x, glm.MessagePrompt)
196-
self.assertEqual(x.context, "you are a cat")
197-
self.assertEqual(x.examples[0].input.content, "are you hungry?")
198-
self.assertEqual(x.examples[0].output.content, "meow!")
199-
self.assertEqual(x.messages[0].content, "hello")
200-
201-
def test_make_generate_message_request_nested(
202-
self,
203-
):
204-
request0 = discuss._make_generate_message_request(
205-
**{
206-
"model": "Dave",
207-
"context": "you are a cat",
208-
"examples": ["hello", "meow", "are you hungry?", "meow!"],
209-
"messages": "Please catch that mouse.",
210-
"temperature": 0.2,
211-
"candidate_count": 7,
212-
}
213-
)
214-
request1 = discuss._make_generate_message_request(
215-
**{
216-
"model": "Dave",
217-
"prompt": {
218-
"context": "you are a cat",
219-
"examples": ["hello", "meow", "are you hungry?", "meow!"],
220-
"messages": "Please catch that mouse.",
221-
},
222-
"temperature": 0.2,
223-
"candidate_count": 7,
224-
}
225-
)
226-
227-
self.assertIsInstance(request0, glm.GenerateMessageRequest)
228-
self.assertIsInstance(request1, glm.GenerateMessageRequest)
229-
self.assertEqual(request0, request1)
230-
231-
@parameterized.parameters(
232-
{"prompt": {}, "context": "You are a cat."},
233-
{"prompt": {"context": "You are a cat."}, "examples": ["hello", "meow"]},
234-
{"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"},
235-
)
236-
def test_make_generate_message_request_flat_prompt_conflict(
237-
self,
238-
context=None,
239-
examples=None,
240-
messages=None,
241-
prompt=None,
242-
):
243-
with self.assertRaises(ValueError):
244-
x = discuss._make_generate_message_request(
245-
model="test",
246-
context=context,
247-
examples=examples,
248-
messages=messages,
249-
prompt=prompt,
250-
)
251-
252-
@parameterized.parameters(
253-
{"kwargs": {"context": "You are a cat."}},
254-
{"kwargs": {"messages": "hello"}},
255-
{"kwargs": {"examples": [["a", "b"], ["c", "d"]]}},
256-
{"kwargs": {"messages": ["hello"], "examples": [["a", "b"], ["c", "d"]]}},
257-
)
258-
def test_reply(self, kwargs):
259-
response = genai.chat(**kwargs)
260-
first_messages = response.messages
261-
262-
self.assertEqual("a", response.last)
263-
self.assertEqual(
264-
[
265-
{"author": "1", "content": "a"},
266-
{"author": "1", "content": "b"},
267-
{"author": "1", "content": "c"},
268-
],
269-
response.candidates,
270-
)
271-
272-
response = response.reply("again")
273-
274-
'''
27531
class AsyncTests(parameterized.TestCase, asynctest.TestCase):
27632
async def test_chat_async(self):
27733
client = async_mock.MagicMock()
@@ -325,7 +81,7 @@ async def fake_generate_message(
32581
observed_response.candidates,
32682
[{"author": "1", "content": "Why did the chicken cross the road?"}],
32783
)
328-
'''
84+
32985

33086
if __name__ == "__main__":
33187
absltest.main()

0 commit comments

Comments
 (0)