Skip to content

Commit 12c25ff

Browse files
Updated tests and current progress on adding search grounding.
1 parent 4f42118 commit 12c25ff

File tree

2 files changed

+92
-3
lines changed

2 files changed

+92
-3
lines changed

google/generativeai/types/content_types.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@
7171
"FunctionLibraryType",
7272
]
7373

74+
Mode = protos.DynamicRetrievalConfig.Mode
75+
76+
ModeOptions = Union[str, str, Mode]
77+
78+
_MODE: dict[ModeOptions, Mode] = {
79+
Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
80+
0: Mode.MODE_UNSPECIFIED,
81+
"mode_unspecified": Mode.MODE_UNSPECIFIED,
82+
"unspecified": Mode.MODE_UNSPECIFIED,
83+
Mode.DYNAMIC: Mode.DYNAMIC,
84+
1: Mode.DYNAMIC,
85+
"mode_dynamic": Mode.DYNAMIC,
86+
"dynamic": Mode.DYNAMIC,
87+
}
88+
89+
90+
def to_mode(x: ModeOptions) -> Mode:
91+
if isinstance(x, str):
92+
x = x.lower()
93+
return _MODE[x]
94+
7495

7596
def pil_to_blob(img):
7697
# When you load an image with PIL you get a subclass of PIL.Image
@@ -656,6 +677,7 @@ class Tool:
656677
def __init__(
657678
self,
658679
function_declarations: Iterable[FunctionDeclarationType] | None = None,
680+
google_search_retrieval: protos.GoogleSearchRetrieval | None = None,
659681
code_execution: protos.CodeExecution | None = None,
660682
):
661683
# The main path doesn't use this but is seems useful.
@@ -676,6 +698,7 @@ def __init__(
676698

677699
self._proto = protos.Tool(
678700
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
701+
google_search_retrieval=google_search_retrieval,
679702
code_execution=code_execution,
680703
)
681704

@@ -723,20 +746,36 @@ def _make_tool(tool: ToolType) -> Tool:
723746
code_execution = tool.code_execution
724747
else:
725748
code_execution = None
726-
return Tool(function_declarations=tool.function_declarations, code_execution=code_execution)
749+
return Tool(
750+
function_declarations=tool.function_declarations,
751+
google_search_retrieval=tool.google_search_retrieval,
752+
code_execution=code_execution,
753+
)
727754
elif isinstance(tool, dict):
728-
if "function_declarations" in tool or "code_execution" in tool:
755+
if (
756+
"function_declarations" in tool
757+
or "google_search_retrieval" in tool
758+
or "code_execution" in tool
759+
):
729760
return Tool(**tool)
730761
else:
731762
fd = tool
732763
return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
733764
elif isinstance(tool, str):
734765
if tool.lower() == "code_execution":
735766
return Tool(code_execution=protos.CodeExecution())
767+
# Check to see if one of the mode enums matches
768+
elif to_mode(tool) == Mode.MODE_UNSPECIFIED or to_mode(tool) == Mode.DYNAMIC:
769+
mode = to_mode(tool)
770+
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval(mode=mode))
736771
else:
737-
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
772+
raise ValueError(
773+
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
774+
)
738775
elif isinstance(tool, protos.CodeExecution):
739776
return Tool(code_execution=tool)
777+
elif isinstance(tool, protos.GoogleSearchRetrieval):
778+
return Tool(google_search_retrieval=tool)
740779
elif isinstance(tool, Iterable):
741780
return Tool(function_declarations=tool)
742781
else:

tests/test_content.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,56 @@ def test_code_execution(self, tools):
433433
t = content_types._make_tool(tools) # Pass code execution into tools
434434
self.assertIsInstance(t.code_execution, protos.CodeExecution)
435435

436+
@parameterized.named_parameters(
437+
["string", "unspecified"],
438+
[
439+
"dictionary",
440+
{"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}},
441+
],
442+
["tuple", ("unspecified", 0.5)],
443+
[
444+
"proto_object",
445+
protos.GoogleSearchRetrieval(
446+
protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5)
447+
),
448+
],
449+
[
450+
"proto_passed_in",
451+
protos.Tool(
452+
google_search_retrieval=protos.GoogleSearchRetrieval(
453+
protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5)
454+
)
455+
),
456+
],
457+
[
458+
"proto_object_list",
459+
[
460+
protos.GoogleSearchRetrieval(
461+
protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5)
462+
)
463+
],
464+
],
465+
[
466+
"proto_passed_in_list",
467+
[
468+
protos.Tool(
469+
google_search_retrieval=protos.GoogleSearchRetrieval(
470+
protos.DynamicRetrievalConfig(
471+
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
472+
)
473+
)
474+
)
475+
],
476+
],
477+
)
478+
def test_search_grounding(self, tools):
479+
if isinstance(tools, Iterable):
480+
t = content_types._make_tools(tools)
481+
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
482+
else:
483+
t = content_types._make_tool(tools) # Pass code execution into tools
484+
self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval)
485+
436486
def test_two_fun_is_one_tool(self):
437487
def a():
438488
pass

0 commit comments

Comments
 (0)