Skip to content

Commit 30337c2

Browse files
MarkDaoust and markmcd
authored
Improve request_options (#297)
* Working on request_options * Add helper_types Change-Id: Idc3e813616413f4ce085c05b771c0127e4dfc886 * format Change-Id: I186e015de97ceece56ee5a97f6edef47ef223d18 * UpdateRequestOptions Change-Id: I9f92466967fb1aa605d442cb143699da4308409b * Add docs Change-Id: I209b2b2ad8d783001b1828cbcac84ca301c11bec * work Change-Id: I00a2e2edb1e9bf3d4f51c0a868a34e044be3c6ff * Fix Py3.9 Change-Id: I8cf0ccac90ba3c4548e7549fec7d0b9b58925e7e * use RequestOptions in tests Change-Id: I92b68bc86330ad874c3765f428a2e64ba220750f * annotations Change-Id: Idbc428075729255d66d2ba8b3bcce0a1d6e8f048 * Update tests/test_discuss.py Co-authored-by: Mark McDonald <[email protected]> * tests Change-Id: Ife30e2cc47bd4c52d2dddafdd85a51df0e42e160 --------- Co-authored-by: Mark McDonald <[email protected]>
1 parent 51d806d commit 30337c2

File tree

14 files changed

+257
-84
lines changed

14 files changed

+257
-84
lines changed

google/generativeai/answer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@
2626
get_default_generative_client,
2727
get_default_generative_async_client,
2828
)
29-
from google.generativeai import string_utils
3029
from google.generativeai.types import model_types
31-
from google.generativeai import models
30+
from google.generativeai.types import helper_types
3231
from google.generativeai.types import safety_types
3332
from google.generativeai.types import content_types
34-
from google.generativeai.types import answer_types
3533
from google.generativeai.types import retriever_types
3634
from google.generativeai.types.retriever_types import MetadataFilter
3735

@@ -245,7 +243,7 @@ def generate_answer(
245243
safety_settings: safety_types.SafetySettingOptions | None = None,
246244
temperature: float | None = None,
247245
client: glm.GenerativeServiceClient | None = None,
248-
request_options: dict[str, Any] | None = None,
246+
request_options: helper_types.RequestOptionsType | None = None,
249247
):
250248
"""
251249
Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
@@ -318,7 +316,7 @@ async def generate_answer_async(
318316
safety_settings: safety_types.SafetySettingOptions | None = None,
319317
temperature: float | None = None,
320318
client: glm.GenerativeServiceClient | None = None,
321-
request_options: dict[str, Any] | None = None,
319+
request_options: helper_types.RequestOptionsType | None = None,
322320
):
323321
"""
324322
Calls the API and returns a `types.Answer` containing the answer.

google/generativeai/discuss.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.generativeai.client import get_default_discuss_async_client
2727
from google.generativeai import string_utils
2828
from google.generativeai.types import discuss_types
29+
from google.generativeai.types import helper_types
2930
from google.generativeai.types import model_types
3031
from google.generativeai.types import palm_safety_types
3132

@@ -316,7 +317,7 @@ def chat(
316317
top_k: float | None = None,
317318
prompt: discuss_types.MessagePromptOptions | None = None,
318319
client: glm.DiscussServiceClient | None = None,
319-
request_options: dict[str, Any] | None = None,
320+
request_options: helper_types.RequestOptionsType | None = None,
320321
) -> discuss_types.ChatResponse:
321322
"""Calls the API and returns a `types.ChatResponse` containing the response.
322323
@@ -416,7 +417,7 @@ async def chat_async(
416417
top_k: float | None = None,
417418
prompt: discuss_types.MessagePromptOptions | None = None,
418419
client: glm.DiscussServiceAsyncClient | None = None,
419-
request_options: dict[str, Any] | None = None,
420+
request_options: helper_types.RequestOptionsType | None = None,
420421
) -> discuss_types.ChatResponse:
421422
request = _make_generate_message_request(
422423
model=model,
@@ -469,7 +470,7 @@ def last(self, message: discuss_types.MessageOptions):
469470
def reply(
470471
self,
471472
message: discuss_types.MessageOptions,
472-
request_options: dict[str, Any] | None = None,
473+
request_options: helper_types.RequestOptionsType | None = None,
473474
) -> discuss_types.ChatResponse:
474475
if isinstance(self._client, glm.DiscussServiceAsyncClient):
475476
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
@@ -537,7 +538,7 @@ def _build_chat_response(
537538
def _generate_response(
538539
request: glm.GenerateMessageRequest,
539540
client: glm.DiscussServiceClient | None = None,
540-
request_options: dict[str, Any] | None = None,
541+
request_options: helper_types.RequestOptionsType | None = None,
541542
) -> ChatResponse:
542543
if request_options is None:
543544
request_options = {}
@@ -553,7 +554,7 @@ def _generate_response(
553554
async def _generate_response_async(
554555
request: glm.GenerateMessageRequest,
555556
client: glm.DiscussServiceAsyncClient | None = None,
556-
request_options: dict[str, Any] | None = None,
557+
request_options: helper_types.RequestOptionsType | None = None,
557558
) -> ChatResponse:
558559
if request_options is None:
559560
request_options = {}
@@ -574,7 +575,7 @@ def count_message_tokens(
574575
messages: discuss_types.MessagesOptions | None = None,
575576
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
576577
client: glm.DiscussServiceAsyncClient | None = None,
577-
request_options: dict[str, Any] | None = None,
578+
request_options: helper_types.RequestOptionsType | None = None,
578579
) -> discuss_types.TokenCount:
579580
model = model_types.make_model_name(model)
580581
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

google/generativeai/embedding.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
import dataclasses
18-
from collections.abc import Iterable, Sequence, Mapping
1917
import itertools
2018
from typing import Any, Iterable, overload, TypeVar, Union, Mapping
2119

@@ -24,7 +22,7 @@
2422
from google.generativeai.client import get_default_generative_client
2523
from google.generativeai.client import get_default_generative_async_client
2624

27-
from google.generativeai import string_utils
25+
from google.generativeai.types import helper_types
2826
from google.generativeai.types import text_types
2927
from google.generativeai.types import model_types
3028
from google.generativeai.types import content_types
@@ -104,7 +102,7 @@ def embed_content(
104102
title: str | None = None,
105103
output_dimensionality: int | None = None,
106104
client: glm.GenerativeServiceClient | None = None,
107-
request_options: dict[str, Any] | None = None,
105+
request_options: helper_types.RequestOptionsType | None = None,
108106
) -> text_types.EmbeddingDict: ...
109107

110108

@@ -116,7 +114,7 @@ def embed_content(
116114
title: str | None = None,
117115
output_dimensionality: int | None = None,
118116
client: glm.GenerativeServiceClient | None = None,
119-
request_options: dict[str, Any] | None = None,
117+
request_options: helper_types.RequestOptionsType | None = None,
120118
) -> text_types.BatchEmbeddingDict: ...
121119

122120

@@ -127,7 +125,7 @@ def embed_content(
127125
title: str | None = None,
128126
output_dimensionality: int | None = None,
129127
client: glm.GenerativeServiceClient = None,
130-
request_options: dict[str, Any] | None = None,
128+
request_options: helper_types.RequestOptionsType | None = None,
131129
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
132130
"""Calls the API to create embeddings for content passed in.
133131
@@ -224,7 +222,7 @@ async def embed_content_async(
224222
title: str | None = None,
225223
output_dimensionality: int | None = None,
226224
client: glm.GenerativeServiceAsyncClient | None = None,
227-
request_options: dict[str, Any] | None = None,
225+
request_options: helper_types.RequestOptionsType | None = None,
228226
) -> text_types.EmbeddingDict: ...
229227

230228

@@ -236,7 +234,7 @@ async def embed_content_async(
236234
title: str | None = None,
237235
output_dimensionality: int | None = None,
238236
client: glm.GenerativeServiceAsyncClient | None = None,
239-
request_options: dict[str, Any] | None = None,
237+
request_options: helper_types.RequestOptionsType | None = None,
240238
) -> text_types.BatchEmbeddingDict: ...
241239

242240

@@ -247,7 +245,7 @@ async def embed_content_async(
247245
title: str | None = None,
248246
output_dimensionality: int | None = None,
249247
client: glm.GenerativeServiceAsyncClient = None,
250-
request_options: dict[str, Any] | None = None,
248+
request_options: helper_types.RequestOptionsType | None = None,
251249
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
252250
"""The async version of `genai.embed_content`."""
253251
model = model_types.make_model_name(model)

google/generativeai/generative_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import google.api_core.exceptions
1616
from google.ai import generativelanguage as glm
1717
from google.generativeai import client
18-
from google.generativeai import string_utils
1918
from google.generativeai.types import content_types
2019
from google.generativeai.types import generation_types
20+
from google.generativeai.types import helper_types
2121
from google.generativeai.types import safety_types
2222

2323

@@ -179,7 +179,7 @@ def generate_content(
179179
stream: bool = False,
180180
tools: content_types.FunctionLibraryType | None = None,
181181
tool_config: content_types.ToolConfigType | None = None,
182-
request_options: dict[str, Any] | None = None,
182+
request_options: helper_types.RequestOptionsType | None = None,
183183
) -> generation_types.GenerateContentResponse:
184184
"""A multipurpose function to generate responses from the model.
185185
@@ -279,7 +279,7 @@ async def generate_content_async(
279279
stream: bool = False,
280280
tools: content_types.FunctionLibraryType | None = None,
281281
tool_config: content_types.ToolConfigType | None = None,
282-
request_options: dict[str, Any] | None = None,
282+
request_options: helper_types.RequestOptionsType | None = None,
283283
) -> generation_types.AsyncGenerateContentResponse:
284284
"""The async version of `GenerativeModel.generate_content`."""
285285
request = self._prepare_request(
@@ -326,7 +326,7 @@ def count_tokens(
326326
safety_settings: safety_types.SafetySettingOptions | None = None,
327327
tools: content_types.FunctionLibraryType | None = None,
328328
tool_config: content_types.ToolConfigType | None = None,
329-
request_options: dict[str, Any] | None = None,
329+
request_options: helper_types.RequestOptionsType | None = None,
330330
) -> glm.CountTokensResponse:
331331
if request_options is None:
332332
request_options = {}
@@ -353,7 +353,7 @@ async def count_tokens_async(
353353
safety_settings: safety_types.SafetySettingOptions | None = None,
354354
tools: content_types.FunctionLibraryType | None = None,
355355
tool_config: content_types.ToolConfigType | None = None,
356-
request_options: dict[str, Any] | None = None,
356+
request_options: helper_types.RequestOptionsType | None = None,
357357
) -> glm.CountTokensResponse:
358358
if request_options is None:
359359
request_options = {}

google/generativeai/models.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.generativeai import operations
2222
from google.generativeai.client import get_default_model_client
2323
from google.generativeai.types import model_types
24+
from google.generativeai.types import helper_types
2425
from google.api_core import operation
2526
from google.api_core import protobuf_helpers
2627
from google.protobuf import field_mask_pb2
@@ -31,7 +32,7 @@ def get_model(
3132
name: model_types.AnyModelNameOptions,
3233
*,
3334
client=None,
34-
request_options: dict[str, Any] | None = None,
35+
request_options: helper_types.RequestOptionsType | None = None,
3536
) -> model_types.Model | model_types.TunedModel:
3637
"""Given a model name, fetch the `types.Model`
3738
@@ -64,7 +65,7 @@ def get_base_model(
6465
name: model_types.BaseModelNameOptions,
6566
*,
6667
client=None,
67-
request_options: dict[str, Any] | None = None,
68+
request_options: helper_types.RequestOptionsType | None = None,
6869
) -> model_types.Model:
6970
"""Get the `types.Model` for the given base model name.
7071
@@ -101,7 +102,7 @@ def get_tuned_model(
101102
name: model_types.TunedModelNameOptions,
102103
*,
103104
client=None,
104-
request_options: dict[str, Any] | None = None,
105+
request_options: helper_types.RequestOptionsType | None = None,
105106
) -> model_types.TunedModel:
106107
"""Get the `types.TunedModel` for the given tuned model name.
107108
@@ -164,7 +165,7 @@ def list_models(
164165
*,
165166
page_size: int | None = 50,
166167
client: glm.ModelServiceClient | None = None,
167-
request_options: dict[str, Any] | None = None,
168+
request_options: helper_types.RequestOptionsType | None = None,
168169
) -> model_types.ModelsIterable:
169170
"""Lists available models.
170171
@@ -198,7 +199,7 @@ def list_tuned_models(
198199
*,
199200
page_size: int | None = 50,
200201
client: glm.ModelServiceClient | None = None,
201-
request_options: dict[str, Any] | None = None,
202+
request_options: helper_types.RequestOptionsType | None = None,
202203
) -> model_types.TunedModelsIterable:
203204
"""Lists available models.
204205
@@ -246,7 +247,7 @@ def create_tuned_model(
246247
input_key: str = "text_input",
247248
output_key: str = "output",
248249
client: glm.ModelServiceClient | None = None,
249-
request_options: dict[str, Any] | None = None,
250+
request_options: helper_types.RequestOptionsType | None = None,
250251
) -> operations.CreateTunedModelOperation:
251252
"""Launches a tuning job to create a TunedModel.
252253
@@ -346,6 +347,7 @@ def create_tuned_model(
346347
top_k=top_k,
347348
tuning_task=tuning_task,
348349
)
350+
349351
operation = client.create_tuned_model(
350352
dict(tuned_model_id=id, tuned_model=tuned_model), **request_options
351353
)
@@ -359,7 +361,7 @@ def update_tuned_model(
359361
updates: None = None,
360362
*,
361363
client: glm.ModelServiceClient | None = None,
362-
request_options: dict[str, Any] | None = None,
364+
request_options: helper_types.RequestOptionsType | None = None,
363365
) -> model_types.TunedModel:
364366
pass
365367

@@ -370,7 +372,7 @@ def update_tuned_model(
370372
updates: dict[str, Any],
371373
*,
372374
client: glm.ModelServiceClient | None = None,
373-
request_options: dict[str, Any] | None = None,
375+
request_options: helper_types.RequestOptionsType | None = None,
374376
) -> model_types.TunedModel:
375377
pass
376378

@@ -380,7 +382,7 @@ def update_tuned_model(
380382
updates: dict[str, Any] | None = None,
381383
*,
382384
client: glm.ModelServiceClient | None = None,
383-
request_options: dict[str, Any] | None = None,
385+
request_options: helper_types.RequestOptionsType | None = None,
384386
) -> model_types.TunedModel:
385387
"""Push updates to the tuned model. Only certain attributes are updatable."""
386388
if request_options is None:
@@ -397,6 +399,7 @@ def update_tuned_model(
397399
"`updates` must be a `dict`.\n"
398400
f"got: {type(updates)}"
399401
)
402+
400403
tuned_model = client.get_tuned_model(name=name, **request_options)
401404

402405
updates = flatten_update_paths(updates)
@@ -438,7 +441,7 @@ def _apply_update(thing, path, value):
438441
def delete_tuned_model(
439442
tuned_model: model_types.TunedModelNameOptions,
440443
client: glm.ModelServiceClient | None = None,
441-
request_options: dict[str, Any] | None = None,
444+
request_options: helper_types.RequestOptionsType | None = None,
442445
) -> None:
443446
if request_options is None:
444447
request_options = {}

0 commit comments

Comments (0)