Skip to content

Commit db0fe2d

Browse files
authored
Simplify list-models iteration - the GAPIC lib handles this. (#86)
* Simplify list-models iteration - the GAPIC lib handles this. * resolve comments
1 parent 0d8d339 commit db0fe2d

File tree

2 files changed

+35
-95
lines changed

2 files changed

+35
-95
lines changed

google/generativeai/models.py

Lines changed: 14 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -137,56 +137,9 @@ def get_base_model_name(
137137
return base_model
138138

139139

140-
def _list_base_models_next_page(page_size, page_token, client):
141-
"""Returns the next page of the base model or tuned model list."""
142-
result = client.list_models(page_size=page_size, page_token=page_token)
143-
result = result._response
144-
result = type(result).to_dict(result)
145-
result["models"] = [model_types.Model(**mod) for mod in result["models"]]
146-
result["page_size"] = page_size
147-
result["page_token"] = result.pop("next_page_token")
148-
result["client"] = client
149-
return result
150-
151-
152-
def _list_tuned_models_next_page(page_size, page_token, client):
153-
"""Returns the next page of the base model or tuned model list."""
154-
result = client.list_tuned_models(
155-
glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token)
156-
)
157-
result = result._response
158-
result = type(result).to_dict(result)
159-
result["models"] = [model_types.decode_tuned_model(mod) for mod in result.pop("tuned_models")]
160-
result["page_size"] = page_size
161-
result["page_token"] = result.pop("next_page_token")
162-
result["client"] = client
163-
return result
164-
165-
166-
def _list_models_iter_pages(
167-
*,
168-
page_size: int | None = None,
169-
select: Literal["base", "tuned"],
170-
client: glm.ModelServiceClient | None = None,
171-
):
172-
if client is None:
173-
client = get_default_model_client()
174-
175-
page_token = None
176-
while True:
177-
if select == "base":
178-
result = _list_base_models_next_page(page_size, page_token=page_token, client=client)
179-
elif select == "tuned":
180-
result = _list_tuned_models_next_page(page_size, page_token=page_token, client=client)
181-
yield from result["models"]
182-
page_token = result["page_token"]
183-
if page_token == "":
184-
break
185-
186-
187140
def list_models(
188141
*,
189-
page_size: int | None = None,
142+
page_size: int | None = 50,
190143
client: glm.ModelServiceClient | None = None,
191144
) -> model_types.ModelsIterable:
192145
"""Lists available models.
@@ -205,12 +158,17 @@ def list_models(
205158
`types.Model` objects.
206159
207160
"""
208-
return _list_models_iter_pages(page_size=page_size, select="base", client=client)
161+
if client is None:
162+
client = get_default_model_client()
163+
164+
for model in client.list_models(page_size=page_size):
165+
model = type(model).to_dict(model)
166+
yield model_types.Model(**model)
209167

210168

211169
def list_tuned_models(
212170
*,
213-
page_size: int | None = None,
171+
page_size: int | None = 50,
214172
client: glm.ModelServiceClient | None = None,
215173
) -> model_types.TunedModelsIterable:
216174
"""Lists available models.
@@ -228,7 +186,12 @@ def list_tuned_models(
228186
Yields:
229187
`types.TunedModel` objects.
230188
"""
231-
return _list_models_iter_pages(page_size=page_size, select="tuned", client=client)
189+
if client is None:
190+
client = get_default_model_client()
191+
192+
for model in client.list_tuned_models(page_size=page_size):
193+
model = type(model).to_dict(model)
194+
yield model_types.decode_tuned_model(model)
232195

233196

234197
def create_tuned_model(

tests/test_models.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import copy
16+
from collections.abc import Iterable
1617
import datetime
1718
import dataclasses
1819
import pathlib
@@ -42,6 +43,8 @@ def setUp(self):
4243

4344
client._client_manager.model_client = self.client
4445

46+
# TODO(markdaoust): Check if typechecking works better if we define this as a
47+
# subclass of `glm.ModelServiceClient`, would pyi files for `glm` help?
4548
def add_client_method(f):
4649
name = f.__name__
4750
setattr(self.client, name, f)
@@ -70,10 +73,6 @@ def get_tuned_model(
7073
response = copy.copy(self.responses["get_tuned_model"])
7174
return response
7275

73-
@dataclasses.dataclass
74-
class ListWrapper:
75-
_response: Any
76-
7776
@add_client_method
7877
def list_models(
7978
request: Union[glm.ListModelsRequest, None] = None,
@@ -85,25 +84,25 @@ def list_models(
8584
request = glm.ListModelsRequest(page_size=page_size, page_token=page_token)
8685
self.assertIsInstance(request, glm.ListModelsRequest)
8786
self.observed_requests.append(request)
88-
response = self.responses["list_models"][request.page_token]
89-
return ListWrapper(response)
87+
response = self.responses["list_models"]
88+
return (item for item in response)
9089

9190
@add_client_method
9291
def list_tuned_models(
9392
request: glm.ListTunedModelsRequest = None,
9493
*,
9594
page_size=None,
9695
page_token=None,
97-
) -> glm.ListModelsResponse:
96+
) -> Iterable[glm.TunedModel]:
9897
if request is None:
9998
request = glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token)
10099
self.assertIsInstance(request, glm.ListTunedModelsRequest)
101100
self.observed_requests.append(request)
102-
response = self.responses["list_tuned_models"][request.page_token]
103-
return ListWrapper(response)
101+
response = self.responses["list_tuned_models"]
102+
return (item for item in response)
104103

105104
@add_client_method
106-
def update_tuned_model(request: glm.UpdateTunedModelRequest):
105+
def update_tuned_model(request: glm.UpdateTunedModelRequest) -> glm.TunedModel:
107106
self.observed_requests.append(request)
108107
response = self.responses.get("update_tuned_model", None)
109108
if response is None:
@@ -156,24 +155,13 @@ def test_fail_with_unscoped_model_name(self, name):
156155
model = models.get_model(name)
157156

158157
def test_list_models(self):
158+
# The low level lib wraps the response in an iterable, so this is a fair test.
159159
self.responses = {
160-
"list_models": {
161-
# The first request doesn't pass a page token
162-
"": glm.ListModelsResponse(
163-
models=[
164-
glm.Model(name="models/fake-bison-001"),
165-
glm.Model(name="models/fake-bison-002"),
166-
],
167-
next_page_token="page1",
168-
),
169-
"page1": glm.ListModelsResponse(
170-
models=[
171-
glm.Model(name="models/fake-bison-003"),
172-
],
173-
# The last page returns an empty page token.
174-
next_page_token="",
175-
),
176-
}
160+
"list_models": [
161+
glm.Model(name="models/fake-bison-001"),
162+
glm.Model(name="models/fake-bison-002"),
163+
glm.Model(name="models/fake-bison-003"),
164+
]
177165
}
178166

179167
found_models = list(models.list_models())
@@ -183,23 +171,12 @@ def test_list_models(self):
183171

184172
def test_list_tuned_models(self):
185173
self.responses = {
186-
"list_tuned_models": {
187-
# The first request doesn't pass a page token
188-
"": glm.ListTunedModelsResponse(
189-
tuned_models=[
190-
glm.TunedModel(name="tunedModels/my-pig-001"),
191-
glm.TunedModel(name="tunedModels/my-pig-002"),
192-
],
193-
next_page_token="page1",
194-
),
195-
"page1": glm.ListTunedModelsResponse(
196-
tuned_models=[
197-
glm.TunedModel(name="tunedModels/my-pig-003"),
198-
],
199-
# The last page returns an empty page token.
200-
next_page_token="",
201-
),
202-
}
174+
# The low level lib wraps the response in an iterable, so this is a fair test.
175+
"list_tuned_models": [
176+
glm.TunedModel(name="tunedModels/my-pig-001"),
177+
glm.TunedModel(name="tunedModels/my-pig-002"),
178+
glm.TunedModel(name="tunedModels/my-pig-003"),
179+
]
203180
}
204181
found_models = list(models.list_tuned_models())
205182
self.assertLen(found_models, 3)

0 commit comments

Comments (0)