
Commit c165b20

Count tokens (#315)
* Upgrade count tokens, add usage metadata. Change-Id: Ib07580b4f22db7e7c58109d6d0c8a27076f204da
* Fix usage_metadata for streaming, fix indentation in __str__. Change-Id: I08ceb067355c933c20d50beb8ee51f9e7ba83ee7
* format Change-Id: I11ed2499b974d7cacde88c740386a5f9a71186ea
* fix typing Change-Id: Ieeae7eefebc330a4b856b0016912f4fce509d780
1 parent 1efbcef commit c165b20

File tree: 6 files changed, +256 -25 lines changed


google/generativeai/generative_models.py

Lines changed: 34 additions & 12 deletions

@@ -322,35 +322,57 @@ async def generate_content_async(
     # fmt: off
     def count_tokens(
         self,
-        contents: content_types.ContentsType,
+        contents: content_types.ContentsType = None,
+        *,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+        tools: content_types.FunctionLibraryType | None = None,
+        tool_config: content_types.ToolConfigType | None = None,
         request_options: dict[str, Any] | None = None,
     ) -> glm.CountTokensResponse:
         if request_options is None:
             request_options = {}
 
         if self._client is None:
             self._client = client.get_default_generative_client()
-        contents = content_types.to_contents(contents)
-        return self._client.count_tokens(
-            glm.CountTokensRequest(model=self.model_name, contents=contents),
-            **request_options,
-        )
+
+        request = glm.CountTokensRequest(
+            model=self.model_name,
+            generate_content_request=self._prepare_request(
+                contents=contents,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                tools=tools,
+                tool_config=tool_config,
+            ))
+        return self._client.count_tokens(request, **request_options)
 
     async def count_tokens_async(
         self,
-        contents: content_types.ContentsType,
+        contents: content_types.ContentsType = None,
+        *,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+        tools: content_types.FunctionLibraryType | None = None,
+        tool_config: content_types.ToolConfigType | None = None,
         request_options: dict[str, Any] | None = None,
     ) -> glm.CountTokensResponse:
         if request_options is None:
             request_options = {}
 
         if self._async_client is None:
             self._async_client = client.get_default_generative_async_client()
-        contents = content_types.to_contents(contents)
-        return await self._async_client.count_tokens(
-            glm.CountTokensRequest(model=self.model_name, contents=contents),
-            **request_options,
-        )
+
+        request = glm.CountTokensRequest(
+            model=self.model_name,
+            generate_content_request=self._prepare_request(
+                contents=contents,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                tools=tools,
+                tool_config=tool_config,
+            ))
+        return await self._async_client.count_tokens(request, **request_options)
 
     # fmt: on

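With this change, count_tokens and count_tokens_async route through the same _prepare_request helper as generate_content, so the reported count covers the full request (tools, tool_config, generation and safety settings), not just the raw contents. A minimal usage sketch — the API key and model name below are placeholders, not part of this commit:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key
model = genai.GenerativeModel("gemini-pro")  # any available model name

# The count now reflects the whole GenerateContentRequest,
# including any tools or config passed alongside the contents.
response = model.count_tokens(
    "Tell me a story about a magic backpack.",
    generation_config={"candidate_count": 1},
)
print(response.total_tokens)
```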
google/generativeai/types/generation_types.py

Lines changed: 14 additions & 2 deletions

@@ -16,10 +16,11 @@
 
 import collections
 import contextlib
-import sys
 from collections.abc import Iterable, AsyncIterable
 import dataclasses
 import itertools
+import json
+import sys
 import textwrap
 from typing import Union
 from typing_extensions import TypedDict

@@ -250,6 +251,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]):
         finish_reason=candidates[-1].finish_reason,
         safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
         citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
+        token_count=candidates[-1].token_count,
     )
 
 

@@ -276,9 +278,11 @@ def _join_prompt_feedbacks(
 
 
 def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
+    chunks = tuple(chunks)
     return glm.GenerateContentResponse(
         candidates=_join_candidate_lists(c.candidates for c in chunks),
         prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
+        usage_metadata=chunks[-1].usage_metadata,
     )
 
 

@@ -373,13 +377,21 @@ def text(self):
     def prompt_feedback(self):
         return self._result.prompt_feedback
 
+    @property
+    def usage_metadata(self):
+        return self._result.usage_metadata
+
     def __str__(self) -> str:
         if self._done:
             _iterator = "None"
         else:
             _iterator = f"<{self._iterator.__class__.__name__}>"
 
-        _result = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})"
+        as_dict = type(self._result).to_dict(self._result)
+        json_str = json.dumps(as_dict, indent=2)
+
+        _result = f"glm.GenerateContentResponse({json_str})"
+        _result = _result.replace("\n", "\n                    ")
 
         if self._error:
             _error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"

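Together these changes surface usage_metadata on GenerateContentResponse and make it survive streaming: _join_chunks first materializes the chunk iterable with tuple() (so it can be iterated twice and indexed) and then copies the usage_metadata reported on the final chunk. A hedged sketch of reading it, assuming the same placeholder setup as above:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key
model = genai.GenerativeModel("gemini-pro")

# Non-streaming: the single response carries its usage_metadata.
response = model.generate_content("Hello")
print(response.usage_metadata)

# Streaming: once the stream is fully consumed, the joined result
# exposes the usage_metadata from the last chunk received.
stream = model.generate_content("Hello", stream=True)
for _ in stream:
    pass
print(stream.usage_metadata)
```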
setup.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.6.2",
+    "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support

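setup.py swaps the PyPI version pin for a PEP 508 direct-URL requirement pointing at a preview build of google-ai-generativelanguage. A quick, standard-library-only way to check which build actually got installed (a sketch, not part of this commit):

```python
from importlib import metadata

# Prints the version string of the installed preview package.
print(metadata.version("google-ai-generativelanguage"))
```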
tests/test_generation.py

Lines changed: 36 additions & 2 deletions

@@ -503,7 +503,24 @@ def test_repr_for_generate_content_response_from_response(self):
             GenerateContentResponse(
                 done=True,
                 iterator=None,
-                result=glm.GenerateContentResponse({'candidates': [{'content': {'parts': [{'text': 'Hello world!'}], 'role': ''}, 'finish_reason': 0, 'safety_ratings': [], 'token_count': 0, 'grounding_attributions': []}]}),
+                result=glm.GenerateContentResponse({
+                  "candidates": [
+                    {
+                      "content": {
+                        "parts": [
+                          {
+                            "text": "Hello world!"
+                          }
+                        ],
+                        "role": ""
+                      },
+                      "finish_reason": 0,
+                      "safety_ratings": [],
+                      "token_count": 0,
+                      "grounding_attributions": []
+                    }
+                  ]
+                }),
             )"""
         )
         self.assertEqual(expected, result)

@@ -522,7 +539,24 @@ def test_repr_for_generate_content_response_from_iterator(self):
             GenerateContentResponse(
                 done=False,
                 iterator=<list_iterator>,
-                result=glm.GenerateContentResponse({'candidates': [{'content': {'parts': [{'text': 'a'}], 'role': ''}, 'finish_reason': 0, 'safety_ratings': [], 'token_count': 0, 'grounding_attributions': []}]}),
+                result=glm.GenerateContentResponse({
+                  "candidates": [
+                    {
+                      "content": {
+                        "parts": [
+                          {
+                            "text": "a"
+                          }
+                        ],
+                        "role": ""
+                      },
+                      "finish_reason": 0,
+                      "safety_ratings": [],
+                      "token_count": 0,
+                      "grounding_attributions": []
+                    }
+                  ]
+                }),
             )"""
         )
         self.assertEqual(expected, result)

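Both expected strings pin down the new __str__ format: the proto result is converted to a dict, pretty-printed with json.dumps(..., indent=2), and every continuation line is re-indented so the JSON block nests under result=. A standalone illustration of that formatting step — the four-space indent here is illustrative, since the library applies its replacement inside a textwrap.dedent template:

```python
import json

as_dict = {"candidates": [{"content": {"parts": [{"text": "a"}], "role": ""}}]}
json_str = json.dumps(as_dict, indent=2)

_result = f"glm.GenerateContentResponse({json_str})"
# Re-indent continuation lines so the JSON nests under "result=".
_result = _result.replace("\n", "\n    ")
print(_result)
```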