Skip to content

Commit 8cc54d1

Browse files
authored
Make model name handling more consistent (and fix some pytype errors). (#36)
* Make model name handling more consistent.
* Fix test.
1 parent 2152b71 commit 8cc54d1

File tree

5 files changed

+26
-15
lines changed

5 files changed

+26
-15
lines changed

google/generativeai/discuss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,10 @@ def count_message_tokens(
501501
context: Optional[str] = None,
502502
examples: Optional[discuss_types.ExamplesOptions] = None,
503503
messages: Optional[discuss_types.MessagesOptions] = None,
504-
model: str = DEFAULT_DISCUSS_MODEL,
504+
model: model_types.ModelNameOptions = DEFAULT_DISCUSS_MODEL,
505505
client: Optional[glm.DiscussServiceAsyncClient] = None,
506506
):
507+
model = model_types.make_model_name(model)
507508
prompt = _make_message_prompt(
508509
prompt, context=context, examples=examples, messages=messages
509510
)

google/generativeai/models.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,13 @@
1919
from google.generativeai.client import get_default_model_client
2020
from google.generativeai.types import model_types
2121

22-
# A bare model name, with no preceding namespace. e.g. foo-bar-001
23-
_BARE_MODEL_NAME = re.compile(r"^\w+-\w+-\d+$")
24-
2522

2623
def get_model(name: str, *, client=None) -> model_types.Model:
2724
"""Get the `types.Model` for the given model name."""
2825
if client is None:
2926
client = get_default_model_client()
3027

31-
# If only a bare model name is passed, give it the structure we expect.
32-
if _BARE_MODEL_NAME.match(name):
33-
name = f"models/{name}"
28+
name = model_types.make_model_name(name)
3429

3530
result = client.get_model(name=name)
3631
result = type(result).to_dict(result)

google/generativeai/text.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from google.generativeai.types import model_types
2727
from google.generativeai.types import safety_types
2828

29+
DEFAULT_TEXT_MODEL = "models/text-bison-001"
30+
2931

3032
def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
3133
if isinstance(prompt, str):
@@ -38,15 +40,15 @@ def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
3840

3941
def _make_generate_text_request(
4042
*,
41-
model: model_types.ModelNameOptions = "models/chat-lamda-001",
43+
model: model_types.ModelNameOptions = DEFAULT_TEXT_MODEL,
4244
prompt: Optional[str] = None,
4345
temperature: Optional[float] = None,
4446
candidate_count: Optional[int] = None,
4547
max_output_tokens: Optional[int] = None,
4648
top_p: Optional[int] = None,
4749
top_k: Optional[int] = None,
4850
safety_settings: Optional[List[safety_types.SafetySettingDict]] = None,
49-
stop_sequences: Union[str, Iterable[str]] = None,
51+
stop_sequences: Optional[Union[str, Iterable[str]]] = None,
5052
) -> glm.GenerateTextRequest:
5153
model = model_types.make_model_name(model)
5254
prompt = _make_text_prompt(prompt=prompt)
@@ -70,15 +72,15 @@ def _make_generate_text_request(
7072

7173
def generate_text(
7274
*,
73-
model: Optional[model_types.ModelNameOptions] = "models/text-bison-001",
75+
model: model_types.ModelNameOptions = DEFAULT_TEXT_MODEL,
7476
prompt: str,
7577
temperature: Optional[float] = None,
7678
candidate_count: Optional[int] = None,
7779
max_output_tokens: Optional[int] = None,
7880
top_p: Optional[float] = None,
7981
top_k: Optional[float] = None,
80-
safety_settings: Optional[Iterable[safety.SafetySettingDict]] = None,
81-
stop_sequences: Union[str, Iterable[str]] = None,
82+
safety_settings: Optional[Iterable[safety_types.SafetySettingDict]] = None,
83+
stop_sequences: Optional[Union[str, Iterable[str]]] = None,
8284
client: Optional[glm.TextServiceClient] = None,
8385
) -> text_types.Completion:
8486
"""Calls the API and returns a `types.Completion` containing the response.
@@ -170,7 +172,9 @@ def _generate_response(
170172
return Completion(_client=client, **response)
171173

172174

173-
def generate_embeddings(model: str, text: str, client: glm.TextServiceClient = None):
175+
def generate_embeddings(
176+
model: model_types.ModelNameOptions, text: str, client: glm.TextServiceClient = None
177+
):
174178
"""Calls the API to create an embedding for the text passed in.
175179
176180
Args:

google/generativeai/types/model_types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""Type definitions for the models service."""
1616

17+
import re
1718
import abc
1819
import dataclasses
1920
from typing import Iterator, List, Optional, Union
@@ -58,10 +59,18 @@ class Model:
5859

5960
ModelNameOptions = Union[str, Model]
6061

62+
# A bare model name, with no preceding namespace. e.g. foo-bar-001
63+
_BARE_MODEL_NAME = re.compile(r"^\w+-\w+-\d+$")
64+
6165

6266
def make_model_name(name: ModelNameOptions):
6367
if isinstance(name, Model):
6468
name = name.name
69+
elif isinstance(name, str):
70+
# If only a bare model name is passed, give it the structure we expect.
71+
if _BARE_MODEL_NAME.match(name):
72+
name = f"models/{name}"
73+
6574
return name
6675

6776

tests/test_text.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def test_make_prompt(self, prompt):
7474
]
7575
)
7676
def test_make_generate_text_request(self, prompt):
77-
x = text_service._make_generate_text_request(prompt=prompt)
78-
self.assertEqual("models/chat-lamda-001", x.model)
77+
x = text_service._make_generate_text_request(
78+
model="chat-bison-001", prompt=prompt
79+
)
80+
self.assertEqual("models/chat-bison-001", x.model)
7981
self.assertIsInstance(x, glm.GenerateTextRequest)
8082

8183
@parameterized.named_parameters(

0 commit comments

Comments (0)