Commit 0dca4ce

Add request options to chat. (#341)
* Add request options to chat
  Change-Id: I6f7e4c980fd7e2a14fec4c3e2d837ad745c69c9a
* fix async
  Change-Id: Ia224e9e8327443a9920ce5d9a877ebb8c272e583
* fix
  Change-Id: I7eed70131346c7d7ffe435c8f6909f7eb3f7e9f7
* merge from main
  Change-Id: I4b92a5bc25aa7bf11bfaf31aa6c029096f3e68bc
* add tests
  Change-Id: I368315f220413ba9508012721e64093372555590
* format
  Change-Id: I26c7fa1f040e7d1ea16068034d78fb9f6cc13db0
1 parent 386994a commit 0dca4ce
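In practical terms, this commit lets callers pass per-request transport options (for example a timeout or retry policy) through `ChatSession.send_message`, `send_message_async`, and automatic function calling. A minimal usage sketch, assuming the package's usual `google.generativeai` entry point; the `RequestOptions` call mirrors the new test below, and a plain dict such as `{"timeout": 120}` is accepted as well:

import google.generativeai as genai
from google.generativeai.types import helper_types

model = genai.GenerativeModel("gemini-pro")
chat = model.start_chat()

# Per-request options are forwarded to the underlying GenerateContent RPC;
# timeout is in seconds, and retry defaults to None.
response = chat.send_message(
    "hello",
    request_options=helper_types.RequestOptions(timeout=120),
)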

File tree

2 files changed: +106 additions, -45 deletions


google/generativeai/generative_models.py

Lines changed: 32 additions & 2 deletions
@@ -444,6 +444,7 @@ def send_message(
         stream: bool = False,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> generation_types.GenerateContentResponse:
         """Sends the conversation history with the added message and returns the model's response.
 
@@ -476,6 +477,9 @@ def send_message(
             safety_settings: Overrides for the model's safety settings.
             stream: If True, yield response chunks as they are generated.
         """
+        if request_options is None:
+            request_options = {}
+
         if self.enable_automatic_function_calling and stream:
             raise NotImplementedError(
                 "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
@@ -504,6 +508,7 @@ def send_message(
             stream=stream,
             tools=tools_lib,
             tool_config=tool_config,
+            request_options=request_options,
         )
 
         self._check_response(response=response, stream=stream)
@@ -516,6 +521,7 @@ def send_message(
                 safety_settings=safety_settings,
                 stream=stream,
                 tools_lib=tools_lib,
+                request_options=request_options,
             )
 
         self._last_sent = content
@@ -546,7 +552,15 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]:
         return function_calls
 
     def _handle_afc(
-        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+        self,
+        *,
+        response,
+        history,
+        generation_config,
+        safety_settings,
+        stream,
+        tools_lib,
+        request_options,
     ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
 
         while function_calls := self._get_function_calls(response):
@@ -572,6 +586,7 @@ def _handle_afc(
                 safety_settings=safety_settings,
                 stream=stream,
                 tools=tools_lib,
+                request_options=request_options,
             )
 
             self._check_response(response=response, stream=stream)
@@ -588,8 +603,12 @@ async def send_message_async(
         stream: bool = False,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `ChatSession.send_message`."""
+        if request_options is None:
+            request_options = {}
+
         if self.enable_automatic_function_calling and stream:
             raise NotImplementedError(
                 "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
@@ -618,6 +637,7 @@ async def send_message_async(
             stream=stream,
             tools=tools_lib,
             tool_config=tool_config,
+            request_options=request_options,
         )
 
         self._check_response(response=response, stream=stream)
@@ -630,6 +650,7 @@ async def send_message_async(
                 safety_settings=safety_settings,
                 stream=stream,
                 tools_lib=tools_lib,
+                request_options=request_options,
             )
 
         self._last_sent = content
@@ -638,7 +659,15 @@ async def send_message_async(
         return response
 
     async def _handle_afc_async(
-        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+        self,
+        *,
+        response,
+        history,
+        generation_config,
+        safety_settings,
+        stream,
+        tools_lib,
+        request_options,
     ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
 
         while function_calls := self._get_function_calls(response):
@@ -664,6 +693,7 @@ async def _handle_afc_async(
                 safety_settings=safety_settings,
                 stream=stream,
                 tools=tools_lib,
+                request_options=request_options,
             )
 
             self._check_response(response=response, stream=stream)
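As the diff shows, the chat paths only default `request_options` to `{}` and pass it along; the expansion into per-call keyword arguments happens at the client boundary, which is what the test mock below observes (`{"timeout": 120, "retry": None}` arriving as kwargs). A hypothetical, condensed sketch of that flow, not the SDK's exact code:

# Hypothetical helper illustrating the plumbing this commit threads through.
def call_generate_content(client, request, request_options=None):
    # Defaulting inside the body avoids a shared mutable {} default argument.
    if request_options is None:
        request_options = {}
    # e.g. {"timeout": 120, "retry": None} arrives as timeout=120, retry=None.
    return client.generate_content(request, **request_options)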

tests/test_generative_models.py

Lines changed: 74 additions & 43 deletions
@@ -12,6 +12,8 @@
 from google.generativeai import generative_models
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
+from google.generativeai.types import helper_types
+
 
 import PIL.Image
 
@@ -37,49 +39,63 @@ def simple_response(text: str) -> glm.GenerateContentResponse:
     return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]})
 
 
+class MockGenerativeServiceClient:
+    def __init__(self, test):
+        self.test = test
+        self.observed_requests = []
+        self.observed_kwargs = []
+        self.responses = collections.defaultdict(list)
+
+    def generate_content(
+        self,
+        request: glm.GenerateContentRequest,
+        **kwargs,
+    ) -> glm.GenerateContentResponse:
+        self.test.assertIsInstance(request, glm.GenerateContentRequest)
+        self.observed_requests.append(request)
+        self.observed_kwargs.append(kwargs)
+        response = self.responses["generate_content"].pop(0)
+        return response
+
+    def stream_generate_content(
+        self,
+        request: glm.GetModelRequest,
+        **kwargs,
+    ) -> Iterable[glm.GenerateContentResponse]:
+        self.observed_requests.append(request)
+        self.observed_kwargs.append(kwargs)
+        response = self.responses["stream_generate_content"].pop(0)
+        return response
+
+    def count_tokens(
+        self,
+        request: glm.CountTokensRequest,
+        **kwargs,
+    ) -> Iterable[glm.GenerateContentResponse]:
+        self.observed_requests.append(request)
+        self.observed_kwargs.append(kwargs)
+        response = self.responses["count_tokens"].pop(0)
+        return response
+
+
 class CUJTests(parameterized.TestCase):
     """Tests are in order with the design doc."""
 
-    def setUp(self):
-        self.client = unittest.mock.MagicMock()
+    @property
+    def observed_requests(self):
+        return self.client.observed_requests
 
-        client_lib._client_manager.clients["generative"] = self.client
-
-        def add_client_method(f):
-            name = f.__name__
-            setattr(self.client, name, f)
-            return f
+    @property
+    def observed_kwargs(self):
+        return self.client.observed_kwargs
 
-        self.observed_requests = []
-        self.responses = collections.defaultdict(list)
+    @property
+    def responses(self):
+        return self.client.responses
 
-        @add_client_method
-        def generate_content(
-            request: glm.GenerateContentRequest,
-            **kwargs,
-        ) -> glm.GenerateContentResponse:
-            self.assertIsInstance(request, glm.GenerateContentRequest)
-            self.observed_requests.append(request)
-            response = self.responses["generate_content"].pop(0)
-            return response
-
-        @add_client_method
-        def stream_generate_content(
-            request: glm.GetModelRequest,
-            **kwargs,
-        ) -> Iterable[glm.GenerateContentResponse]:
-            self.observed_requests.append(request)
-            response = self.responses["stream_generate_content"].pop(0)
-            return response
-
-        @add_client_method
-        def count_tokens(
-            request: glm.CountTokensRequest,
-            **kwargs,
-        ) -> Iterable[glm.GenerateContentResponse]:
-            self.observed_requests.append(request)
-            response = self.responses["count_tokens"].pop(0)
-            return response
+    def setUp(self):
+        self.client = MockGenerativeServiceClient(self)
+        client_lib._client_manager.clients["generative"] = self.client
 
     def test_hello(self):
         # Generate text from text prompt
@@ -451,7 +467,7 @@ def test_copy_history(self):
         chat1 = model.start_chat()
         chat1.send_message("hello1")
 
-        chat2 = copy.deepcopy(chat1)
+        chat2 = copy.copy(chat1)
         chat2.send_message("hello2")
 
         chat1.send_message("hello3")
@@ -810,7 +826,7 @@ def test_async_code_match(self, obj, aobj):
         )
 
         asource = re.sub(" *?# type: ignore", "", asource)
-        self.assertEqual(source, asource)
+        self.assertEqual(source, asource, f"error in {obj=}")
 
     def test_repr_for_unary_non_streamed_response(self):
         model = generative_models.GenerativeModel(model_name="gemini-pro")
@@ -1208,15 +1224,30 @@ def test_repr_for_system_instruction(self):
         self.assertIn("system_instruction='Be excellent.'", result)
 
     def test_count_tokens_called_with_request_options(self):
-        self.client.count_tokens = unittest.mock.MagicMock()
-        request = unittest.mock.ANY
+        self.responses["count_tokens"].append(glm.CountTokensResponse())
         request_options = {"timeout": 120}
 
-        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
         model = generative_models.GenerativeModel("gemini-pro-vision")
         model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options)
 
-        self.client.count_tokens.assert_called_once_with(request, **request_options)
+        self.assertEqual(request_options, self.observed_kwargs[0])
+
+    def test_chat_with_request_options(self):
+        self.responses["generate_content"].append(
+            glm.GenerateContentResponse(
+                {
+                    "candidates": [{"finish_reason": "STOP"}],
+                }
+            )
+        )
+        request_options = {"timeout": 120}
+
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+        chat.send_message("hello", request_options=helper_types.RequestOptions(**request_options))
+
+        request_options["retry"] = None
+        self.assertEqual(request_options, self.observed_kwargs[0])
 
 
 if __name__ == "__main__":
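Since `MockGenerativeServiceClient` records the kwargs of every RPC in `observed_kwargs`, asserting that options reach the transport takes a single comparison. A hypothetical further test in the same pattern (name and values illustrative only):

def test_count_tokens_forwards_timeout(self):
    self.responses["count_tokens"].append(glm.CountTokensResponse())

    model = generative_models.GenerativeModel("gemini-pro")
    model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options={"timeout": 30})

    # A plain dict passes through to the RPC kwargs unchanged.
    self.assertEqual({"timeout": 30}, self.observed_kwargs[0])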
