
Commit 125c436

Add count_text_tokens, and expose operations. (#76)
* Format and fix pytype errors.
* Use get_base_model_name in create_tuned_model.
* Resolve comments.
1 parent 46db06f commit 125c436
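In short: this commit adds a public count_text_tokens helper to the text service, re-exports the operations module (list_operations, get_operation) at the package top level, and factors tuned-model resolution into a new get_base_model_name helper used by create_tuned_model. A minimal usage sketch of the new surface — not taken from the diff itself; the API key is a placeholder, and calling list_operations with no arguments is an assumption based only on the import name:

    import google.generativeai as genai

    genai.configure(api_key="...")  # placeholder key

    # New: count the tokens in a prompt without generating any text.
    tokens = genai.count_text_tokens(
        model="models/text-bison-001",
        prompt="Tell me a story about a magic backpack.",
    )
    print(tokens["token_count"])  # a plain dict, e.g. {'token_count': 7}

    # Newly exposed: enumerate long-running tuning operations.
    for operation in genai.list_operations():  # argument-free call is an assumption
        print(operation)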

8 files changed: +154 −22 lines changed


google/generativeai/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -77,6 +77,7 @@
 
 from google.generativeai.text import generate_text
 from google.generativeai.text import generate_embeddings
+from google.generativeai.text import count_text_tokens
 
 from google.generativeai.models import list_models
 from google.generativeai.models import list_tuned_models
@@ -89,6 +90,10 @@
 from google.generativeai.models import update_tuned_model
 from google.generativeai.models import delete_tuned_model
 
+from google.generativeai.operations import list_operations
+from google.generativeai.operations import get_operation
+
+
 from google.generativeai.client import configure
 
 __version__ = version.__version__

google/generativeai/discuss.py

Lines changed: 1 addition & 1 deletion
@@ -565,7 +565,7 @@ def count_message_tokens(
     messages: discuss_types.MessagesOptions | None = None,
     model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
     client: glm.DiscussServiceAsyncClient | None = None,
-):
+) -> discuss_types.TokenCount:
     model = model_types.make_model_name(model)
     prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
 
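With the new annotation, count_message_tokens advertises a discuss_types.TokenCount return type (a TypedDict added later in this commit), so callers get a typed dict rather than an opaque value. A hedged sketch of what callers see, assuming the usual top-level re-export and the default chat model name:

    import google.generativeai as genai

    count = genai.count_message_tokens(prompt="Hello", model="models/chat-bison-001")
    print(count["token_count"])  # an int inside a plain dict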
google/generativeai/models.py

Lines changed: 28 additions & 7 deletions
@@ -74,7 +74,7 @@ def get_base_model(name: model_types.BaseModelNameOptions, *, client=None) -> mo
 
     name = model_types.make_model_name(name)
     if not name.startswith("models/"):
-        raise ValueError("Base model names must start with `models/`")
+        raise ValueError(f"Base model names must start with `models/`, got: {name}")
 
     result = client.get_model(name=name)
     result = type(result).to_dict(result)
@@ -112,6 +112,31 @@ def get_tuned_model(
     return model_types.decode_tuned_model(result)
 
 
+def get_base_model_name(
+    model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None
+):
+    if isinstance(model, str):
+        if model.startswith("tunedModels/"):
+            model = get_model(model, client=client)
+            base_model = model.base_model
+        else:
+            base_model = model
+    elif isinstance(model, model_types.TunedModel):
+        base_model = model.base_model
+    elif isinstance(model, model_types.Model):
+        base_model = model.name
+    elif isinstance(model, glm.Model):
+        base_model = model.name
+    elif isinstance(model, glm.TunedModel):
+        base_model = getattr(model, "base_model", None)
+        if not base_model:
+            base_model = model.tuned_model_source.base_model
+    else:
+        raise TypeError(f"Cannot understand model: {model}")
+
+    return base_model
+
+
 def _list_base_models_next_page(page_size, page_token, client):
     """Returns the next page of the base model or tuned model list."""
     result = client.list_models(page_size=page_size, page_token=page_token)
@@ -270,18 +295,14 @@ def create_tuned_model(
         client = get_default_model_client()
 
     source_model_name = model_types.make_model_name(source_model)
+    base_model_name = get_base_model_name(source_model)
     if source_model_name.startswith("models/"):
         source_model = {"base_model": source_model_name}
     elif source_model_name.startswith("tunedModels/"):
-        source_model = client.get_tuned_model(name=source_model_name)
-        base_model = source_model.base_model
-        if not base_model:
-            base_model = source_model.tuned_model_source.base_model
-
         source_model = {
             "tuned_model_source": {
                 "tuned_model": source_model_name,
-                "base_model": base_model,
+                "base_model": base_model_name,
             }
         }
     else:
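The new get_base_model_name helper collapses every accepted model spelling into a plain "models/..." name; only a bare "tunedModels/..." string costs a service round trip (via get_model), which is what lets the create_tuned_model hunk above drop its hand-rolled get_tuned_model lookup. A hedged sketch of the mapping — the tuned-model name below is illustrative, not from the diff:

    from google.generativeai import models
    from google.generativeai.types import model_types

    # A base-model string passes through unchanged.
    models.get_base_model_name("models/text-bison-001")  # -> "models/text-bison-001"

    # A TunedModel object already carries its base model, so no lookup is needed.
    tuned = model_types.TunedModel(
        name="tunedModels/my-copy-001",  # hypothetical name
        base_model="models/text-bison-001",
    )
    models.get_base_model_name(tuned)  # -> "models/text-bison-001"

    # Only a bare "tunedModels/..." string triggers get_model(), i.e. an API call.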

google/generativeai/text.py

Lines changed: 18 additions & 0 deletions
@@ -24,6 +24,7 @@
 from google.generativeai import string_utils
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
+from google.generativeai import models
 from google.generativeai.types import safety_types
 
 DEFAULT_TEXT_MODEL = "models/text-bison-001"
@@ -217,6 +218,23 @@ def _generate_response(
     return Completion(_client=client, **response)
 
 
+def count_text_tokens(
+    model: model_types.AnyModelNameOptions,
+    prompt: str,
+    client: glm.TextServiceClient | None = None,
+) -> text_types.TokenCount:
+    base_model = models.get_base_model_name(model)
+
+    if client is None:
+        client = get_default_text_client()
+
+    result = client.count_text_tokens(
+        glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt})
+    )
+
+    return type(result).to_dict(result)
+
+
 @overload
 def generate_embeddings(
     model: model_types.BaseModelNameOptions,
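count_text_tokens resolves whatever model reference it is given down to the base model first, since the CountTextTokens RPC is addressed to base text models, and only then issues the request. A sketch of the equivalent low-level request it builds, with the client wiring elided:

    import google.ai.generativelanguage as glm

    request = glm.CountTextTokensRequest(
        model="models/text-bison-001",  # already resolved to the base model
        prompt={"text": "Tell me a story about a magic backpack."},
    )
    # client.count_text_tokens(request) -> glm.CountTextTokensResponse(token_count=...)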

google/generativeai/types/discuss_types.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@
 ]
 
 
+class TokenCount(TypedDict):
+    token_count: int
+
+
 class MessageDict(TypedDict):
     """A dict representation of a `glm.Message`."""
 
google/generativeai/types/model_types.py

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ def make_model_name(name: AnyModelNameOptions):
         raise TypeError("Expected: str, Model, or TunedModel")
 
     if not (name.startswith("models/") or name.startswith("tunedModels/")):
-        raise ValueError("Model names should start with `models/` or `tunedModels/`")
+        raise ValueError(f"Model names should start with `models/` or `tunedModels/`, got: {name}")
 
     return name

google/generativeai/types/text_types.py

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,10 @@
 __all__ = ["Completion"]
 
 
+class TokenCount(TypedDict):
+    token_count: int
+
+
 class EmbeddingDict(TypedDict):
     embedding: list[float]
 
tests/test_text.py

Lines changed: 93 additions & 13 deletions
@@ -12,8 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import os
+import copy
 import unittest
 import unittest.mock as mock
 
@@ -22,6 +21,7 @@
 from google.generativeai import text as text_service
 from google.generativeai import client
 from google.generativeai.types import safety_types
+from google.generativeai.types import model_types
 from absl.testing import absltest
 from absl.testing import parameterized
 
@@ -31,8 +31,9 @@ def setUp(self):
         self.client = unittest.mock.MagicMock()
 
         client._client_manager.text_client = self.client
+        client._client_manager.model_client = self.client
 
-        self.observed_request = None
+        self.observed_requests = []
 
         self.responses = {}
 
@@ -45,23 +46,37 @@ def add_client_method(f):
         def generate_text(
             request: glm.GenerateTextRequest,
         ) -> glm.GenerateTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["generate_text"]
 
         @add_client_method
         def embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["embed_text"]
 
         @add_client_method
         def batch_embed_text(
             request: glm.EmbedTextRequest,
         ) -> glm.EmbedTextResponse:
-            self.observed_request = request
+            self.observed_requests.append(request)
             return self.responses["batch_embed_text"]
 
+        @add_client_method
+        def count_text_tokens(
+            request: glm.CountTextTokensRequest,
+        ) -> glm.CountTextTokensResponse:
+            self.observed_requests.append(request)
+            return self.responses["count_text_tokens"]
+
+        @add_client_method
+        def get_tuned_model(name) -> glm.TunedModel:
+            request = glm.GetTunedModelRequest(name=name)
+            self.observed_requests.append(request)
+            response = copy.copy(self.responses["get_tuned_model"])
+            return response
+
     @parameterized.named_parameters(
         [
             dict(testcase_name="string", prompt="Hello how are"),
@@ -99,7 +114,7 @@ def test_generate_embeddings(self, model, text):
         emb = text_service.generate_embeddings(model=model, text=text)
 
         self.assertIsInstance(emb, dict)
-        self.assertEqual(self.observed_request, glm.EmbedTextRequest(model=model, text=text))
+        self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text))
         self.assertIsInstance(emb["embedding"][0], float)
 
     @parameterized.named_parameters(
@@ -123,8 +138,7 @@ def test_generate_embeddings_batch(self, model, text):
 
         self.assertIsInstance(emb, dict)
         self.assertEqual(
-            self.observed_request,
-            glm.BatchEmbedTextRequest(model=model, texts=text),
+            self.observed_requests[-1], glm.BatchEmbedTextRequest(model=model, texts=text)
         )
         self.assertIsInstance(emb["embedding"][0], list)
 
@@ -160,7 +174,7 @@ def test_generate_response(self, *, prompt, **kwargs):
         complete = text_service.generate_text(prompt=prompt, **kwargs)
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs
             ),
@@ -188,15 +202,15 @@ def test_stop_string(self):
         complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")
 
         self.assertEqual(
-            self.observed_request,
+            self.observed_requests[-1],
             glm.GenerateTextRequest(
                 model="models/text-bison-001",
                 prompt=glm.TextPrompt(text="Hello"),
                 stop_sequences=["stop"],
             ),
         )
         # Just make sure it made it into the request object.
-        self.assertEqual(self.observed_request.stop_sequences, ["stop"])
+        self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"])
 
     @parameterized.named_parameters(
         [
@@ -251,7 +265,7 @@ def test_safety_settings(self, safety_settings):
         )
 
         self.assertEqual(
-            self.observed_request.safety_settings[0].category,
+            self.observed_requests[-1].safety_settings[0].category,
             safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
         )
 
@@ -367,6 +381,72 @@ def test_candidate_citations(self):
             6,
         )
 
+    @parameterized.named_parameters(
+        [
+            dict(testcase_name="base-name", model="models/text-bison-001"),
+            dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"),
+            dict(
+                testcase_name="model",
+                model=model_types.Model(
+                    name="models/text-bison-001",
+                    base_model_id="text-bison-001",
+                    version="001",
+                    display_name="🦬",
+                    description="🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬",
+                    input_token_limit=8000,
+                    output_token_limit=4000,
+                    supported_generation_methods=["GenerateText"],
+                ),
+            ),
+            dict(
+                testcase_name="tuned_model",
+                model=model_types.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_model",
+                model=glm.Model(
+                    name="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-001",
+                    base_model="models/text-bison-001",
+                ),
+            ),
+            dict(
+                testcase_name="glm_tuned_model_nested",
+                model=glm.TunedModel(
+                    name="tunedModels/bipedal-pangolin-002",
+                    tuned_model_source={
+                        "tuned_model": "tunedModels/bipedal-pangolin-002",
+                        "base_model": "models/text-bison-001",
+                    },
+                ),
+            ),
+        ]
+    )
+    def test_count_text_tokens(self, model):
+        self.responses["get_tuned_model"] = glm.TunedModel(
+            name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001"
+        )
+        self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7)
+
+        response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.")
+        self.assertEqual({"token_count": 7}, response)
+
+        should_look_up_model = isinstance(model, str) and model.startswith("tunedModels/")
+        if should_look_up_model:
+            self.assertLen(self.observed_requests, 2)
+            self.assertEqual(
+                self.observed_requests[0],
+                glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"),
+            )
+
 
 if __name__ == "__main__":
     absltest.main()
