Skip to content

Commit 7d2c17e

Browse files
Function calling mode patch (#271)
* Handle function_calling_mode when passed as a dict with allowed_func_names * Add tests for tool_config * Update alias FunctionCallingConfigType * format * Remove unnecessary test functions * Replace alias with concrete data types for instance checks
1 parent ba6b439 commit 7d2c17e

File tree

3 files changed

+202
-2
lines changed

3 files changed

+202
-2
lines changed

google/generativeai/types/content_types.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,12 +698,25 @@ class FunctionCallingConfigDict(TypedDict):
698698
allowed_function_names: list[str]
699699

700700

701-
FunctionCallingConfigType = Union[FunctionCallingConfigDict, glm.FunctionCallingConfig]
701+
FunctionCallingConfigType = Union[
702+
FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig
703+
]
702704

703705

704706
def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig:
705-
if isinstance(obj, (FunctionCallingMode, str, int)):
707+
if isinstance(obj, glm.FunctionCallingConfig):
708+
return obj
709+
elif isinstance(obj, (FunctionCallingMode, str, int)):
706710
obj = {"mode": to_function_calling_mode(obj)}
711+
elif isinstance(obj, dict):
712+
obj = obj.copy()
713+
mode = obj.pop("mode")
714+
obj["mode"] = to_function_calling_mode(mode)
715+
else:
716+
raise TypeError(
717+
f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n",
718+
obj,
719+
)
707720

708721
return glm.FunctionCallingConfig(obj)
709722

tests/test_generative_models.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,99 @@ def test_tools(self):
613613
self.assertLen(obr.tools, 1)
614614
self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)
615615

616+
@parameterized.named_parameters(
617+
dict(
618+
testcase_name="test_FunctionCallingMode_str",
619+
tool_config={"function_calling_config": "any"},
620+
expected_tool_config={
621+
"function_calling_config": {
622+
"mode": content_types.FunctionCallingMode.ANY,
623+
"allowed_function_names": [],
624+
}
625+
},
626+
),
627+
dict(
628+
testcase_name="test_FunctionCallingMode_int",
629+
tool_config={"function_calling_config": 1},
630+
expected_tool_config={
631+
"function_calling_config": {
632+
"mode": content_types.FunctionCallingMode.AUTO,
633+
"allowed_function_names": [],
634+
}
635+
},
636+
),
637+
dict(
638+
testcase_name="test_FunctionCallingMode",
639+
tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE},
640+
expected_tool_config={
641+
"function_calling_config": {
642+
"mode": content_types.FunctionCallingMode.NONE,
643+
"allowed_function_names": [],
644+
}
645+
},
646+
),
647+
dict(
648+
testcase_name="test_glm_FunctionCallingConfig",
649+
tool_config={
650+
"function_calling_config": glm.FunctionCallingConfig(
651+
mode=content_types.FunctionCallingMode.AUTO
652+
)
653+
},
654+
expected_tool_config={
655+
"function_calling_config": {
656+
"mode": content_types.FunctionCallingMode.AUTO,
657+
"allowed_function_names": [],
658+
}
659+
},
660+
),
661+
dict(
662+
testcase_name="test_FunctionCallingConfigDict",
663+
tool_config={
664+
"function_calling_config": {
665+
"mode": "mode_auto",
666+
"allowed_function_names": ["datetime", "greetings", "random"],
667+
}
668+
},
669+
expected_tool_config={
670+
"function_calling_config": {
671+
"mode": content_types.FunctionCallingMode.AUTO,
672+
"allowed_function_names": ["datetime", "greetings", "random"],
673+
}
674+
},
675+
),
676+
dict(
677+
testcase_name="test_glm_ToolConfig",
678+
tool_config=glm.ToolConfig(
679+
function_calling_config=glm.FunctionCallingConfig(
680+
mode=content_types.FunctionCallingMode.NONE
681+
)
682+
),
683+
expected_tool_config={
684+
"function_calling_config": {
685+
"mode": content_types.FunctionCallingMode.NONE,
686+
"allowed_function_names": [],
687+
}
688+
},
689+
),
690+
)
691+
def test_tool_config(self, tool_config, expected_tool_config):
692+
tools = dict(
693+
function_declarations=[
694+
dict(name="datetime", description="Returns the current UTC date and time."),
695+
dict(name="greetings", description="Returns a greeting."),
696+
dict(name="random", description="Returns a random number."),
697+
]
698+
)
699+
self.responses["generate_content"] = [simple_response("echo echo")]
700+
701+
model = generative_models.GenerativeModel("gemini-pro", tools=tools)
702+
_ = model.generate_content("Hello", tools=[tools], tool_config=tool_config)
703+
704+
req = self.observed_requests[0]
705+
706+
self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3)
707+
self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config)
708+
616709
@parameterized.named_parameters(
617710
["bare_str", "talk like a pirate", simple_part("talk like a pirate")],
618711
[

tests/test_generative_models_async.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from google.generativeai import client as client_lib
2525
from google.generativeai import generative_models
26+
from google.generativeai.types import content_types
2627
import google.ai.generativelanguage as glm
2728

2829
from absl.testing import absltest
@@ -107,6 +108,99 @@ async def responses():
107108

108109
self.assertEqual(response.text, "world!")
109110

111+
@parameterized.named_parameters(
112+
dict(
113+
testcase_name="test_FunctionCallingMode_str",
114+
tool_config={"function_calling_config": "any"},
115+
expected_tool_config={
116+
"function_calling_config": {
117+
"mode": content_types.FunctionCallingMode.ANY,
118+
"allowed_function_names": [],
119+
}
120+
},
121+
),
122+
dict(
123+
testcase_name="test_FunctionCallingMode_int",
124+
tool_config={"function_calling_config": 1},
125+
expected_tool_config={
126+
"function_calling_config": {
127+
"mode": content_types.FunctionCallingMode.AUTO,
128+
"allowed_function_names": [],
129+
}
130+
},
131+
),
132+
dict(
133+
testcase_name="test_FunctionCallingMode",
134+
tool_config={"function_calling_config": content_types.FunctionCallingMode.NONE},
135+
expected_tool_config={
136+
"function_calling_config": {
137+
"mode": content_types.FunctionCallingMode.NONE,
138+
"allowed_function_names": [],
139+
}
140+
},
141+
),
142+
dict(
143+
testcase_name="test_glm_FunctionCallingConfig",
144+
tool_config={
145+
"function_calling_config": glm.FunctionCallingConfig(
146+
mode=content_types.FunctionCallingMode.AUTO
147+
)
148+
},
149+
expected_tool_config={
150+
"function_calling_config": {
151+
"mode": content_types.FunctionCallingMode.AUTO,
152+
"allowed_function_names": [],
153+
}
154+
},
155+
),
156+
dict(
157+
testcase_name="test_FunctionCallingConfigDict",
158+
tool_config={
159+
"function_calling_config": {
160+
"mode": "mode_auto",
161+
"allowed_function_names": ["datetime", "greetings", "random"],
162+
}
163+
},
164+
expected_tool_config={
165+
"function_calling_config": {
166+
"mode": content_types.FunctionCallingMode.AUTO,
167+
"allowed_function_names": ["datetime", "greetings", "random"],
168+
}
169+
},
170+
),
171+
dict(
172+
testcase_name="test_glm_ToolConfig",
173+
tool_config=glm.ToolConfig(
174+
function_calling_config=glm.FunctionCallingConfig(
175+
mode=content_types.FunctionCallingMode.NONE
176+
)
177+
),
178+
expected_tool_config={
179+
"function_calling_config": {
180+
"mode": content_types.FunctionCallingMode.NONE,
181+
"allowed_function_names": [],
182+
}
183+
},
184+
),
185+
)
186+
async def test_tool_config(self, tool_config, expected_tool_config):
187+
tools = dict(
188+
function_declarations=[
189+
dict(name="datetime", description="Returns the current UTC date and time."),
190+
dict(name="greetings", description="Returns a greeting."),
191+
dict(name="random", description="Returns a random number."),
192+
]
193+
)
194+
self.responses["generate_content"] = [simple_response("echo echo")]
195+
196+
model = generative_models.GenerativeModel("gemini-pro", tools=tools)
197+
_ = await model.generate_content_async("Hello", tools=[tools], tool_config=tool_config)
198+
199+
req = self.observed_requests[0]
200+
201+
self.assertLen(type(req.tools[0]).to_dict(req.tools[0]).get("function_declarations"), 3)
202+
self.assertEqual(type(req.tool_config).to_dict(req.tool_config), expected_tool_config)
203+
110204
@parameterized.named_parameters(
111205
["basic", "Hello"],
112206
["list", ["Hello"]],

0 commit comments

Comments
 (0)