Commit 375bd6a

Cleanup async tests
1 parent 7d75dfa · commit 375bd6a

2 files changed: +37 −35 lines


tests/test_discuss_async.py

Lines changed: 37 additions & 32 deletions

@@ -12,52 +12,57 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest.mock
 
-import asynctest
-from asynctest import mock as async_mock
+import sys
+import unittest
+
+if sys.version_info < (3, 11):
+    import asynctest
+    from asynctest import mock as async_mock
 
 import google.ai.generativelanguage as glm
 
 from google.generativeai import discuss
-from google.generativeai import client
-import google.generativeai as genai
 from absl.testing import absltest
 from absl.testing import parameterized
 
-# TODO: replace returns with 'assert' statements
+bases = (parameterized.TestCase,)
 
+if sys.version_info < (3, 11):
+    bases = bases + (asynctest.TestCase,)
 
-class AsyncTests(parameterized.TestCase, asynctest.TestCase):
-    async def test_chat_async(self):
-        client = async_mock.MagicMock()
+unittest.skipIf(sys.version_info >= (3,11), "asynctest is not suported on python 3.11+")
+class AsyncTests(*bases):
+    if sys.version_info < (3, 11):
+        async def test_chat_async(self):
+            client = async_mock.MagicMock()
 
-        observed_request = None
+            observed_request = None
 
-        async def fake_generate_message(
-            request: glm.GenerateMessageRequest,
-        ) -> glm.GenerateMessageResponse:
-            nonlocal observed_request
-            observed_request = request
-            return glm.GenerateMessageResponse(
-                candidates=[
-                    glm.Message(
-                        author="1", content="Why did the chicken cross the road?"
-                    )
-                ]
-            )
+            async def fake_generate_message(
+                request: glm.GenerateMessageRequest,
+            ) -> glm.GenerateMessageResponse:
+                nonlocal observed_request
+                observed_request = request
+                return glm.GenerateMessageResponse(
+                    candidates=[
+                        glm.Message(
+                            author="1", content="Why did the chicken cross the road?"
+                        )
+                    ]
+                )
 
-        client.generate_message = fake_generate_message
+            client.generate_message = fake_generate_message
 
-        observed_response = await discuss.chat_async(
-            model="models/bard",
-            context="Example Prompt",
-            examples=[["Example from human", "Example response from AI"]],
-            messages=["Tell me a joke"],
-            temperature=0.75,
-            candidate_count=1,
-            client=client,
-        )
+            observed_response = await discuss.chat_async(
+                model="models/bard",
+                context="Example Prompt",
+                examples=[["Example from human", "Example response from AI"]],
+                messages=["Tell me a joke"],
+                temperature=0.75,
+                candidate_count=1,
+                client=client,
+            )
 
             self.assertEqual(
                 observed_request,
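
The new version of this module builds the test case's base classes at import time: bases is a tuple that only includes asynctest.TestCase when the interpreter is older than Python 3.11, and the class statement unpacks it with class AsyncTests(*bases). A minimal, self-contained sketch of that pattern, not part of this commit (the ConditionalAsyncTests name and the trivial test bodies are illustrative only):

    import sys
    import unittest

    # asynctest does not work on Python 3.11+ (the reason for the version
    # checks in the diff above), so only import it on older interpreters.
    if sys.version_info < (3, 11):
        import asynctest

        bases = (asynctest.TestCase,)  # already a unittest.TestCase subclass
    else:
        bases = (unittest.TestCase,)


    # Star-unpacking a tuple of base classes in a class statement is ordinary
    # Python 3 syntax, mirroring "class AsyncTests(*bases):" in the diff.
    class ConditionalAsyncTests(*bases):
        # Coroutine tests are only defined when the async-capable base exists;
        # the module still imports, and the plain test still runs, either way.
        if sys.version_info < (3, 11):

            async def test_async_path(self):
                self.assertTrue(True)

        def test_sync_path(self):
            self.assertEqual(1 + 1, 2)

One note on the skipIf line in the diff: unittest.skipIf(...) only takes effect when applied as a decorator (@unittest.skipIf(...)) or when its return value wraps a test; a bare call builds the skip decorator and discards it.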

tests/test_text.py

Lines changed: 0 additions & 3 deletions

@@ -17,9 +17,6 @@
 import unittest
 import unittest.mock as mock
 
-import asynctest
-from asynctest import mock as async_mock
-
 import google.ai.generativelanguage as glm
 
 from google.generativeai import text as text_service
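
Not part of this commit, but for context: on Python 3.8+ the standard library can run coroutine test methods by itself through unittest.IsolatedAsyncioTestCase, with unittest.mock.AsyncMock standing in for asynctest's mocks. A rough sketch of how the chat test in this commit could look without the asynctest dependency, assuming the same discuss.chat_async signature exercised above:

    import unittest
    from unittest import mock

    import google.ai.generativelanguage as glm
    from google.generativeai import discuss


    class AsyncChatTests(unittest.IsolatedAsyncioTestCase):
        async def test_chat_async(self):
            # AsyncMock (stdlib, Python 3.8+) replaces asynctest's mock here.
            client = mock.MagicMock()
            client.generate_message = mock.AsyncMock(
                return_value=glm.GenerateMessageResponse(
                    candidates=[
                        glm.Message(
                            author="1", content="Why did the chicken cross the road?"
                        )
                    ]
                )
            )

            response = await discuss.chat_async(
                model="models/bard",
                context="Example Prompt",
                examples=[["Example from human", "Example response from AI"]],
                messages=["Tell me a joke"],
                temperature=0.75,
                candidate_count=1,
                client=client,
            )

            client.generate_message.assert_awaited_once()

This is only a sketch; the repository's own tests keep asynctest for older interpreters, as the diff above shows.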
