Skip to content

Commit 2152b71

Browse files
authored
Support bare model names in palm.get_model (#35)
Right now we require model names to be specified as `"models/type-size-version"` (with the `models/` prefix), but that seems slightly unintuitive. This change adds the `models/` prefix if it detects the expected *bare* model string (e.g. `foo-bar-001`). Adds `test_models.py` too.
1 parent a99f45e commit 2152b71

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

google/generativeai/models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,26 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import dataclasses
15+
import re
1616
from typing import Optional, List
1717

1818
import google.ai.generativelanguage as glm
1919
from google.generativeai.client import get_default_model_client
2020
from google.generativeai.types import model_types
2121

22+
# A bare model name, with no preceding namespace. e.g. foo-bar-001
23+
_BARE_MODEL_NAME = re.compile(r"^\w+-\w+-\d+$")
2224

23-
def get_model(name, *, client=None) -> model_types.Model:
25+
26+
def get_model(name: str, *, client=None) -> model_types.Model:
2427
"""Get the `types.Model` for the given model name."""
2528
if client is None:
2629
client = get_default_model_client()
2730

31+
# If only a bare model name is passed, give it the structure we expect.
32+
if _BARE_MODEL_NAME.match(name):
33+
name = f"models/{name}"
34+
2835
result = client.get_model(name=name)
2936
result = type(result).to_dict(result)
3037
return model_types.Model(**result)
@@ -37,7 +44,7 @@ def __init__(
3744
page_size: int,
3845
page_token: Optional[str],
3946
models: List[model_types.Model],
40-
client: Optional[glm.ModelServiceClient]
47+
client: Optional[glm.ModelServiceClient],
4148
):
4249
self._page_size = page_size
4350
self._page_token = page_token

tests/test_models.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from unittest import mock
16+
17+
from absl.testing import absltest
18+
19+
import google.ai.generativelanguage as glm
20+
from google.ai.generativelanguage_v1beta2.types import model
21+
22+
from google.generativeai import models
23+
24+
_FAKE_MODEL = model.Model(
25+
name="models/fake-model-001",
26+
base_model_id="",
27+
version="001",
28+
display_name="Fake Model",
29+
description="A fake model",
30+
input_token_limit=123,
31+
output_token_limit=234,
32+
supported_generation_methods=[],
33+
)
34+
35+
36+
class UnitTests(absltest.TestCase):
37+
def test_model_prefix(self):
38+
"""Test `models/` prefix applies to get_model calls when necessary."""
39+
# The SUT needs a concrete return type from `get_model`, so set up a real-enough client.
40+
fake_client = mock.Mock(spec=glm.ModelServiceClient)
41+
fake_client.get_model.return_value = _FAKE_MODEL
42+
43+
# Ensure that we don't mess with correctly structure args.
44+
models.get_model(name="models/text-bison-001", client=fake_client)
45+
fake_client.get_model.assert_called_with(name="models/text-bison-001")
46+
47+
# Ensure that we do correct bare models.
48+
models.get_model(name="text-bison-001", client=fake_client)
49+
fake_client.get_model.assert_called_with(name="models/text-bison-001")
50+
51+
# And unknown structure is not touched.
52+
models.get_model(name="unknown_string", client=fake_client)
53+
fake_client.get_model.assert_called_with(name="unknown_string")

0 commit comments

Comments
 (0)