Skip to content

Commit fcec2a8

Browse files
authored
Add basic SI and tool_config support (#257)
* Add basic SI and tool_config support * fix tool_config=None * Fix format. * Fix types
1 parent 6967517 commit fcec2a8

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

google/generativeai/generative_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def __init__(
7373
safety_settings: safety_types.SafetySettingOptions | None = None,
7474
generation_config: generation_types.GenerationConfigType | None = None,
7575
tools: content_types.FunctionLibraryType | None = None,
76+
tool_config: content_types.ToolConfigType | None = None,
77+
system_instructions: content_types.ContentType | None = None,
7678
):
7779
if "/" not in model_name:
7880
model_name = "models/" + model_name
@@ -83,6 +85,16 @@ def __init__(
8385
self._generation_config = generation_types.to_generation_config_dict(generation_config)
8486
self._tools = content_types.to_function_library(tools)
8587

88+
if tool_config is None:
89+
self._tool_config = None
90+
else:
91+
self._tool_config = content_types.to_tool_config(tool_config)
92+
93+
if system_instructions is None:
94+
self._system_instructions = None
95+
else:
96+
self._system_instructions = content_types.to_content(system_instructions)
97+
8698
self._client = None
8799
self._async_client = None
88100

@@ -110,6 +122,7 @@ def _prepare_request(
110122
generation_config: generation_types.GenerationConfigType | None = None,
111123
safety_settings: safety_types.SafetySettingOptions | None = None,
112124
tools: content_types.FunctionLibraryType | None,
125+
tool_config: content_types.ToolConfigType | None,
113126
) -> glm.GenerateContentRequest:
114127
"""Creates a `glm.GenerateContentRequest` from raw inputs."""
115128
if not contents:
@@ -119,6 +132,11 @@ def _prepare_request(
119132
if tools_lib is not None:
120133
tools_lib = tools_lib.to_proto()
121134

135+
if tool_config is None:
136+
tool_config = self._tool_config
137+
else:
138+
tool_config = content_types.to_tool_config(tool_config)
139+
122140
contents = content_types.to_contents(contents)
123141

124142
generation_config = generation_types.to_generation_config_dict(generation_config)
@@ -136,6 +154,8 @@ def _prepare_request(
136154
generation_config=merged_gc,
137155
safety_settings=merged_ss,
138156
tools=tools_lib,
157+
tool_config=tool_config,
158+
system_instructions=self._system_instructions,
139159
)
140160

141161
def _get_tools_lib(
@@ -154,6 +174,7 @@ def generate_content(
154174
safety_settings: safety_types.SafetySettingOptions | None = None,
155175
stream: bool = False,
156176
tools: content_types.FunctionLibraryType | None = None,
177+
tool_config: content_types.ToolConfigType | None = None,
157178
request_options: dict[str, Any] | None = None,
158179
) -> generation_types.GenerateContentResponse:
159180
"""A multipurpose function to generate responses from the model.
@@ -215,6 +236,7 @@ def generate_content(
215236
generation_config=generation_config,
216237
safety_settings=safety_settings,
217238
tools=tools,
239+
tool_config=tool_config,
218240
)
219241
if self._client is None:
220242
self._client = client.get_default_generative_client()
@@ -252,6 +274,7 @@ async def generate_content_async(
252274
safety_settings: safety_types.SafetySettingOptions | None = None,
253275
stream: bool = False,
254276
tools: content_types.FunctionLibraryType | None = None,
277+
tool_config: content_types.ToolConfigType | None = None,
255278
request_options: dict[str, Any] | None = None,
256279
) -> generation_types.AsyncGenerateContentResponse:
257280
"""The async version of `GenerativeModel.generate_content`."""
@@ -260,6 +283,7 @@ async def generate_content_async(
260283
generation_config=generation_config,
261284
safety_settings=safety_settings,
262285
tools=tools,
286+
tool_config=tool_config,
263287
)
264288
if self._async_client is None:
265289
self._async_client = client.get_default_generative_async_client()
@@ -388,6 +412,7 @@ def send_message(
388412
safety_settings: safety_types.SafetySettingOptions = None,
389413
stream: bool = False,
390414
tools: content_types.FunctionLibraryType | None = None,
415+
tool_config: content_types.ToolConfigType | None = None,
391416
) -> generation_types.GenerateContentResponse:
392417
"""Sends the conversation history with the added message and returns the model's response.
393418
@@ -446,6 +471,7 @@ def send_message(
446471
safety_settings=safety_settings,
447472
stream=stream,
448473
tools=tools_lib,
474+
tool_config=tool_config,
449475
)
450476

451477
self._check_response(response=response, stream=stream)
@@ -529,6 +555,7 @@ async def send_message_async(
529555
safety_settings: safety_types.SafetySettingOptions = None,
530556
stream: bool = False,
531557
tools: content_types.FunctionLibraryType | None = None,
558+
tool_config: content_types.ToolConfigType | None = None,
532559
) -> generation_types.AsyncGenerateContentResponse:
533560
"""The async version of `ChatSession.send_message`."""
534561
if self.enable_automatic_function_calling and stream:
@@ -557,6 +584,7 @@ async def send_message_async(
557584
safety_settings=safety_settings,
558585
stream=stream,
559586
tools=tools_lib,
587+
tool_config=tool_config,
560588
)
561589

562590
self._check_response(response=response, stream=stream)

google/generativeai/types/content_types.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,69 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No
661661
return lib
662662
else:
663663
return FunctionLibrary(tools=lib)
664+
665+
666+
FunctionCallingMode = glm.FunctionCallingConfig.Mode
667+
668+
# fmt: off
669+
_FUNCTION_CALLING_MODE = {
670+
1: FunctionCallingMode.AUTO,
671+
FunctionCallingMode.AUTO: FunctionCallingMode.AUTO,
672+
"mode_auto": FunctionCallingMode.AUTO,
673+
"auto": FunctionCallingMode.AUTO,
674+
675+
2: FunctionCallingMode.ANY,
676+
FunctionCallingMode.ANY: FunctionCallingMode.ANY,
677+
"mode_any": FunctionCallingMode.ANY,
678+
"any": FunctionCallingMode.ANY,
679+
680+
3: FunctionCallingMode.NONE,
681+
FunctionCallingMode.NONE: FunctionCallingMode.NONE,
682+
"mode_none": FunctionCallingMode.NONE,
683+
"none": FunctionCallingMode.NONE,
684+
}
685+
# fmt: on
686+
687+
FunctionCallingModeType = Union[FunctionCallingMode, str, int]
688+
689+
690+
def to_function_calling_mode(x: FunctionCallingModeType) -> FunctionCallingMode:
691+
if isinstance(x, str):
692+
x = x.lower()
693+
return _FUNCTION_CALLING_MODE[x]
694+
695+
696+
class FunctionCallingConfigDict(TypedDict):
697+
mode: FunctionCallingModeType
698+
allowed_function_names: list[str]
699+
700+
701+
FunctionCallingConfigType = Union[FunctionCallingConfigDict, glm.FunctionCallingConfig]
702+
703+
704+
def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig:
705+
if isinstance(obj, (FunctionCallingMode, str, int)):
706+
obj = {"mode": to_function_calling_mode(obj)}
707+
708+
return glm.FunctionCallingConfig(obj)
709+
710+
711+
class ToolConfigDict:
712+
function_calling_config: FunctionCallingConfigType
713+
714+
715+
ToolConfigType = Union[ToolConfigDict, glm.ToolConfig]
716+
717+
718+
def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig:
719+
if isinstance(obj, glm.ToolConfig):
720+
return obj
721+
elif isinstance(obj, dict):
722+
fcc = obj.pop("function_calling_config")
723+
fcc = to_function_calling_config(fcc)
724+
obj["function_calling_config"] = fcc
725+
return glm.ToolConfig(**obj)
726+
else:
727+
raise TypeError(
728+
f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj
729+
)

0 commit comments

Comments
 (0)