
Commit d894807

MarkDaoust and markmcd authored
Automatic function calling. (#201)
* Starting automatic function calling
* Working on AFC
* Fix typos
* Add tools overrides for generate_content and send_message
* Add initial AFC loop.
* Basic debugging, streaming's probably broken.
* Add error with stream=True
* format
* add pydantic
* fix tests
* replace __init__
* Fix pytype
* Remove property
* format
* working on it
* working on it
* working on it
* format
* Add test for schema gen
* Split test
* Fix type anno & classmethod
* fixup: black
* Fix mutable defaults.
* Fix mutable defaults

---------

Co-authored-by: Mark McDonald <[email protected]>
1 parent b28f2d1 commit d894807
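What this commit buys in practice: a model can be handed plain Python functions, and `ChatSession` runs the call/response loop for you. A minimal usage sketch under stated assumptions (the `multiply` function, the prompt, and the API-key placeholder are illustrative, not part of this commit):

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # assumed; any valid key works

def multiply(a: float, b: float) -> float:
    """Returns a * b."""
    return a * b

# Plain callables are accepted as tools; the SDK generates the function
# declaration schema from the signature and docstring.
model = genai.GenerativeModel(model_name="gemini-pro", tools=[multiply])

chat = model.start_chat(enable_automatic_function_calling=True)
# send_message executes any FunctionCall parts locally and sends the
# results back until the model answers in plain text.
response = chat.send_message("What is 234551 times 325552?")
print(response.text)
```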

File tree

6 files changed (+747 / -63 lines)


google/generativeai/generative_models.py

Lines changed: 158 additions & 28 deletions
@@ -71,7 +71,7 @@ def __init__(
         model_name: str = "gemini-pro",
         safety_settings: safety_types.SafetySettingOptions | None = None,
         generation_config: generation_types.GenerationConfigType | None = None,
-        tools: content_types.ToolsType = None,
+        tools: content_types.FunctionLibraryType | None = None,
     ):
         if "/" not in model_name:
             model_name = "models/" + model_name
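The widened annotation above is what makes that possible: `FunctionLibraryType` covers more shapes than the old `ToolsType`. A hedged sketch of two forms `to_function_library` is expected to accept (the exact accepted inputs are defined in `content_types`, which is not shown in this diff):

```python
import google.ai.generativelanguage as glm
import google.generativeai as genai

def get_time(city: str) -> str:
    """Returns the current time in `city` (illustrative stub)."""
    return "12:00"

# 1) A plain callable, wrapped into a FunctionLibrary automatically.
model_a = genai.GenerativeModel("gemini-pro", tools=[get_time])

# 2) An explicit glm.Tool, for hand-written declarations.
model_b = genai.GenerativeModel(
    "gemini-pro",
    tools=glm.Tool(
        function_declarations=[
            glm.FunctionDeclaration(
                name="get_time",
                description="Returns the current time in a city.",
            )
        ]
    ),
)
```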
@@ -80,7 +80,7 @@ def __init__(
             safety_settings, harm_category_set="new"
         )
         self._generation_config = generation_types.to_generation_config_dict(generation_config)
-        self._tools = content_types.to_tools(tools)
+        self._tools = content_types.to_function_library(tools)

         self._client = None
         self._async_client = None
@@ -94,8 +94,9 @@ def __str__(self):
             f"""\
             genai.GenerativeModel(
                 model_name='{self.model_name}',
-                generation_config={self._generation_config}.
-                safety_settings={self._safety_settings}
+                generation_config={self._generation_config},
+                safety_settings={self._safety_settings},
+                tools={self._tools},
             )"""
         )

@@ -107,12 +108,16 @@ def _prepare_request(
         contents: content_types.ContentsType,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None,
     ) -> glm.GenerateContentRequest:
         """Creates a `glm.GenerateContentRequest` from raw inputs."""
         if not contents:
             raise TypeError("contents must not be empty")

+        tools_lib = self._get_tools_lib(tools)
+        if tools_lib is not None:
+            tools_lib = tools_lib.to_proto()
+
         contents = content_types.to_contents(contents)

         generation_config = generation_types.to_generation_config_dict(generation_config)
@@ -129,19 +134,26 @@ def _prepare_request(
             contents=contents,
             generation_config=merged_gc,
             safety_settings=merged_ss,
-            tools=self._tools,
-            **kwargs,
+            tools=tools_lib,
         )

+    def _get_tools_lib(
+        self, tools: content_types.FunctionLibraryType
+    ) -> content_types.FunctionLibrary | None:
+        if tools is None:
+            return self._tools
+        else:
+            return content_types.to_function_library(tools)
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
    ) -> generation_types.GenerateContentResponse:
        """A multipurpose function to generate responses from the model.

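`_get_tools_lib` gives the new `tools=` arguments a simple precedence rule: a per-call value wins, and `None` falls back to the library the model was constructed with. Continuing the hypothetical `multiply` example from the first sketch:

```python
# tools=None (the default) falls through to self._tools:
response = model.generate_content("What is 57 times 44?")

# A per-call value overrides the model-level tools for this call only:
response = model.generate_content("What is 57 times 44?", tools=[multiply])
```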
@@ -201,7 +213,7 @@ def generate_content(
             contents=contents,
             generation_config=generation_config,
             safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
         )
         if self._client is None:
             self._client = client.get_default_generative_client()
@@ -230,15 +242,15 @@ async def generate_content_async(
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
         stream: bool = False,
+        tools: content_types.FunctionLibraryType | None = None,
         request_options: dict[str, Any] | None = None,
-        **kwargs,
    ) -> generation_types.AsyncGenerateContentResponse:
        """The async version of `GenerativeModel.generate_content`."""
        request = self._prepare_request(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
-            **kwargs,
+            tools=tools,
        )
        if self._async_client is None:
            self._async_client = client.get_default_generative_async_client()
@@ -299,6 +311,7 @@ def start_chat(
         self,
         *,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ) -> ChatSession:
         """Returns a `genai.ChatSession` attached to this model.

@@ -314,6 +327,7 @@ def start_chat(
         return ChatSession(
             model=self,
             history=history,
+            enable_automatic_function_calling=enable_automatic_function_calling,
         )


@@ -341,11 +355,13 @@ def __init__(
         self,
         model: GenerativeModel,
         history: Iterable[content_types.StrictContentType] | None = None,
+        enable_automatic_function_calling: bool = False,
     ):
         self.model: GenerativeModel = model
         self._history: list[glm.Content] = content_types.to_contents(history)
         self._last_sent: glm.Content | None = None
         self._last_received: generation_types.BaseGenerateContentResponse | None = None
+        self.enable_automatic_function_calling = enable_automatic_function_calling

     def send_message(
         self,
@@ -354,7 +370,7 @@ def send_message(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.GenerateContentResponse:
         """Sends the conversation history with the added message and returns the model's response.

@@ -387,23 +403,52 @@ def send_message(
             safety_settings: Overrides for the model's safety settings.
             stream: If True, yield response chunks as they are generated.
         """
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)

         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
+
         response = self.model.generate_content(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
         )

+        self._check_response(response=response, stream=stream)
+
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = self._handle_afc(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )
+
+        self._last_sent = content
+        self._last_received = response
+
+        return response
+
+    def _check_response(self, *, response, stream):
         if response.prompt_feedback.block_reason:
             raise generation_types.BlockedPromptException(response.prompt_feedback)

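After an automatic round trip, the chat history should carry the intermediate tool traffic, not just the user text and the final answer. A hedged way to see that, reusing the same `"function_call" in part` membership test the loop below relies on (the exact part layout may differ):

```python
# Expected sequence after one tool call:
#   user -> text, model -> function_call,
#   user -> function_response, model -> text
for content in chat.history:
    kinds = [
        "function_call" if "function_call" in part
        else "function_response" if "function_response" in part
        else "text"
        for part in content.parts
    ]
    print(content.role, kinds)
```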
@@ -415,10 +460,49 @@ def send_message(
         ):
             raise generation_types.StopCandidateException(response.candidates[0])

-        self._last_sent = content
-        self._last_received = response
+    def _get_function_calls(self, response) -> list[glm.FunctionCall]:
+        candidates = response.candidates
+        if len(candidates) != 1:
+            raise ValueError(
+                f"Automatic function calling only works with 1 candidate, got: {len(candidates)}"
+            )
+        parts = candidates[0].content.parts
+        function_calls = [part.function_call for part in parts if part and "function_call" in part]
+        return function_calls
+
+    def _handle_afc(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)

-        return response
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = self.model.generate_content(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response

     async def send_message_async(
         self,
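The loop above uses two behaviors of `content_types.FunctionLibrary` that this diff only exercises, never defines: indexing by a `glm.FunctionCall` returns the matching declaration (so `callable(...)` can be checked), and calling the library executes the function and wraps its result in a `glm.Part`. A rough stand-in for that contract, assumed rather than copied from the real class:

```python
import google.ai.generativelanguage as glm

class MiniFunctionLibrary:
    """Illustrative stand-in for content_types.FunctionLibrary."""

    def __init__(self, functions):
        self._functions = {f.__name__: f for f in functions}

    def __getitem__(self, fc: glm.FunctionCall):
        # _handle_afc checks callable(tools_lib[fc]) before executing.
        return self._functions[fc.name]

    def __call__(self, fc: glm.FunctionCall) -> glm.Part | None:
        fn = self._functions[fc.name]
        if not callable(fn):
            return None  # the case _handle_afc guards against
        result = fn(**dict(fc.args))
        return glm.Part(
            function_response=glm.FunctionResponse(
                name=fc.name, response={"result": result}
            )
        )
```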
@@ -427,42 +511,88 @@ async def send_message_async(
         generation_config: generation_types.GenerationConfigType = None,
         safety_settings: safety_types.SafetySettingOptions = None,
         stream: bool = False,
-        **kwargs,
+        tools: content_types.FunctionLibraryType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `ChatSession.send_message`."""
+        if self.enable_automatic_function_calling and stream:
+            raise NotImplementedError(
+                "The `google.generativeai` SDK does not yet support `stream=True` with "
+                "`enable_automatic_function_calling=True`"
+            )
+
+        tools_lib = self.model._get_tools_lib(tools)
+
         content = content_types.to_content(content)
+
         if not content.role:
             content.role = self._USER_ROLE
+
         history = self.history[:]
         history.append(content)

         generation_config = generation_types.to_generation_config_dict(generation_config)
         if generation_config.get("candidate_count", 1) > 1:
             raise ValueError("Can't chat with `candidate_count > 1`")
+
         response = await self.model.generate_content_async(
             contents=history,
             generation_config=generation_config,
             safety_settings=safety_settings,
             stream=stream,
-            **kwargs,
+            tools=tools_lib,
         )

-        if response.prompt_feedback.block_reason:
-            raise generation_types.BlockedPromptException(response.prompt_feedback)
+        self._check_response(response=response, stream=stream)

-        if not stream:
-            if response.candidates[0].finish_reason not in (
-                glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED,
-                glm.Candidate.FinishReason.STOP,
-                glm.Candidate.FinishReason.MAX_TOKENS,
-            ):
-                raise generation_types.StopCandidateException(response.candidates[0])
+        if self.enable_automatic_function_calling and tools_lib is not None:
+            self.history, content, response = await self._handle_afc_async(
+                response=response,
+                history=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools_lib=tools_lib,
+            )

         self._last_sent = content
         self._last_received = response

         return response

+    async def _handle_afc_async(
+        self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:
+
+        while function_calls := self._get_function_calls(response):
+            if not all(callable(tools_lib[fc]) for fc in function_calls):
+                break
+            history.append(response.candidates[0].content)
+
+            function_response_parts: list[glm.Part] = []
+            for fc in function_calls:
+                fr = tools_lib(fc)
+                assert fr is not None, (
+                    "This should never happen, it should only return None if the declaration "
+                    "is not callable, and that's guarded against above."
+                )
+                function_response_parts.append(fr)
+
+            send = glm.Content(role=self._USER_ROLE, parts=function_response_parts)
+            history.append(send)
+
+            response = await self.model.generate_content_async(
+                contents=history,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                stream=stream,
+                tools=tools_lib,
+            )
+
+            self._check_response(response=response, stream=stream)
+
+        *history, content = history
+        return history, content, response
+
     def __copy__(self):
         return ChatSession(
             model=self.model,
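The async path mirrors the sync one, including the `stream=True` restriction. A short driving sketch (the `multiply` tool is the same hypothetical as before; a configured API key is assumed):

```python
import asyncio
import google.generativeai as genai

def multiply(a: float, b: float) -> float:
    """Returns a * b."""
    return a * b

async def main() -> None:
    model = genai.GenerativeModel("gemini-pro", tools=[multiply])
    chat = model.start_chat(enable_automatic_function_calling=True)
    # stream=True here would raise NotImplementedError, as in send_message.
    response = await chat.send_message_async("What is 234551 times 325552?")
    print(response.text)

asyncio.run(main())
```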
