Skip to content

Commit b72d107

Browse files
authored
Fixup (#128)
* Add smoke tests for count_tokens. * Fix docstring.
1 parent 2bdec6a commit b72d107

File tree

4 files changed

+53
-3
lines changed

4 files changed

+53
-3
lines changed

google/generativeai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
3131
genai.configure(api_key=os.environ['API_KEY'])
3232
33-
model = genai.Model(name='gemini-pro')
33+
model = genai.GenerativeModel(name='gemini-pro')
3434
response = model.generate_content('Please summarise this document: ...')
3535
3636
print(response.text)

google/generativeai/generative_models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,18 @@ async def generate_content_async(
274274
def count_tokens(
275275
self, contents: content_types.ContentsType
276276
) -> glm.CountTokensResponse:
277+
if self._client is None:
278+
self._client = client.get_default_generative_client()
277279
contents = content_types.to_contents(contents)
278-
return self._client.count_tokens(model=self.model_name, contents=contents)
280+
return self._client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))
279281

280282
async def count_tokens_async(
281283
self, contents: content_types.ContentsType
282284
) -> glm.CountTokensResponse:
285+
if self._async_client is None:
286+
self._async_client = client.get_default_generative_async_client()
283287
contents = content_types.to_contents(contents)
284-
return await self._client.count_tokens(model=self.model_name, contents=contents)
288+
return await self._async_client.count_tokens(glm.CountTokensRequest(model=self.model_name, contents=contents))
285289
# fmt: on
286290

287291
def start_chat(

tests/test_generative_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ def stream_generate_content(
5656
response = self.responses["stream_generate_content"].pop(0)
5757
return response
5858

59+
@add_client_method
60+
def count_tokens(
61+
request: glm.CountTokensRequest,
62+
) -> Iterable[glm.GenerateContentResponse]:
63+
self.observed_requests.append(request)
64+
response = self.responses["count_tokens"].pop(0)
65+
return response
66+
5967
def test_hello(self):
6068
# Generate text from text prompt
6169
model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -564,6 +572,21 @@ def test_chat_streaming_unexpected_stop(self):
564572
chat.rewind()
565573
self.assertLen(chat.history, 0)
566574

575+
@parameterized.named_parameters(
576+
["basic", "Hello"],
577+
["list", ["Hello"]],
578+
[
579+
"list2",
580+
[{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
581+
],
582+
["contents", [{"role": "user", "parts": ["hello"]}]],
583+
)
584+
def test_count_tokens_smoke(self, contents):
585+
self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
586+
model = generative_models.GenerativeModel("gemini-mm-m")
587+
response = model.count_tokens(contents)
588+
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
589+
567590
@parameterized.named_parameters(
568591
[
569592
"GenerateContentResponse",

tests/test_generative_models_async.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ async def stream_generate_content(
6262
response = self.responses["stream_generate_content"].pop(0)
6363
return response
6464

65+
@add_client_method
66+
async def count_tokens(
67+
request: glm.CountTokensRequest,
68+
) -> Iterable[glm.GenerateContentResponse]:
69+
self.observed_requests.append(request)
70+
response = self.responses["count_tokens"].pop(0)
71+
return response
72+
6573
async def test_basic(self):
6674
# Generate text from text prompt
6775
model = generative_models.GenerativeModel(model_name="gemini-m")
@@ -98,6 +106,21 @@ async def responses():
98106

99107
self.assertEqual(response.text, "world!")
100108

109+
@parameterized.named_parameters(
110+
["basic", "Hello"],
111+
["list", ["Hello"]],
112+
[
113+
"list2",
114+
[{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
115+
],
116+
["contents", [{"role": "user", "parts": ["hello"]}]],
117+
)
118+
async def test_count_tokens_smoke(self, contents):
119+
self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
120+
model = generative_models.GenerativeModel("gemini-mm-m")
121+
response = await model.count_tokens_async(contents)
122+
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
123+
101124

102125
if __name__ == "__main__":
103126
absltest.main()

0 commit comments

Comments
 (0)