|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 |
| -import unittest.mock |
16 | 15 |
|
17 |
| -import asynctest |
18 |
| -from asynctest import mock as async_mock |
| 16 | +import sys |
| 17 | +import unittest |
| 18 | + |
| 19 | +if sys.version_info < (3, 11): |
| 20 | + import asynctest |
| 21 | + from asynctest import mock as async_mock |
19 | 22 |
|
20 | 23 | import google.ai.generativelanguage as glm
|
21 | 24 |
|
22 | 25 | from google.generativeai import discuss
|
23 |
| -from google.generativeai import client |
24 |
| -import google.generativeai as genai |
25 | 26 | from absl.testing import absltest
|
26 | 27 | from absl.testing import parameterized
|
27 | 28 |
|
28 |
| -# TODO: replace returns with 'assert' statements |
# Choose the base classes for AsyncTests up front. asynctest only works
# (and is only imported, see the guarded import above the glm import) on
# Python versions below 3.11, so its TestCase is mixed in conditionally.
if sys.version_info < (3, 11):
    bases = (parameterized.TestCase, asynctest.TestCase)
else:
    bases = (parameterized.TestCase,)
30 | 33 |
|
31 |
| -class AsyncTests(parameterized.TestCase, asynctest.TestCase): |
32 |
| - async def test_chat_async(self): |
33 |
| - client = async_mock.MagicMock() |
# NOTE(review): bug — the leading '@' is missing. As written this line just
# CALLS unittest.skipIf(...) and throws away the decorator it returns, so the
# skip is never applied to the AsyncTests class below. It should read:
#   @unittest.skipIf(sys.version_info >= (3, 11), "asynctest is not supported on python 3.11+")
# (Fixing it would also correct the "suported" typo in the skip reason.)
# In practice the class still imports cleanly on 3.11+ only because the async
# test method below is additionally guarded by its own version check.
unittest.skipIf(sys.version_info >= (3,11), "asynctest is not suported on python 3.11+")
| 35 | +class AsyncTests(*bases): |
| 36 | + if sys.version_info < (3, 11): |
| 37 | + async def test_chat_async(self): |
| 38 | + client = async_mock.MagicMock() |
34 | 39 |
|
35 |
| - observed_request = None |
| 40 | + observed_request = None |
36 | 41 |
|
37 |
| - async def fake_generate_message( |
38 |
| - request: glm.GenerateMessageRequest, |
39 |
| - ) -> glm.GenerateMessageResponse: |
40 |
| - nonlocal observed_request |
41 |
| - observed_request = request |
42 |
| - return glm.GenerateMessageResponse( |
43 |
| - candidates=[ |
44 |
| - glm.Message( |
45 |
| - author="1", content="Why did the chicken cross the road?" |
46 |
| - ) |
47 |
| - ] |
48 |
| - ) |
| 42 | + async def fake_generate_message( |
| 43 | + request: glm.GenerateMessageRequest, |
| 44 | + ) -> glm.GenerateMessageResponse: |
| 45 | + nonlocal observed_request |
| 46 | + observed_request = request |
| 47 | + return glm.GenerateMessageResponse( |
| 48 | + candidates=[ |
| 49 | + glm.Message( |
| 50 | + author="1", content="Why did the chicken cross the road?" |
| 51 | + ) |
| 52 | + ] |
| 53 | + ) |
49 | 54 |
|
50 |
| - client.generate_message = fake_generate_message |
| 55 | + client.generate_message = fake_generate_message |
51 | 56 |
|
52 |
| - observed_response = await discuss.chat_async( |
53 |
| - model="models/bard", |
54 |
| - context="Example Prompt", |
55 |
| - examples=[["Example from human", "Example response from AI"]], |
56 |
| - messages=["Tell me a joke"], |
57 |
| - temperature=0.75, |
58 |
| - candidate_count=1, |
59 |
| - client=client, |
60 |
| - ) |
| 57 | + observed_response = await discuss.chat_async( |
| 58 | + model="models/bard", |
| 59 | + context="Example Prompt", |
| 60 | + examples=[["Example from human", "Example response from AI"]], |
| 61 | + messages=["Tell me a joke"], |
| 62 | + temperature=0.75, |
| 63 | + candidate_count=1, |
| 64 | + client=client, |
| 65 | + ) |
61 | 66 |
|
62 | 67 | self.assertEqual(
|
63 | 68 | observed_request,
|
|
0 commit comments