Skip to content

Commit 91fc30c

Browse files
Update and add aditional test cases
1 parent fd39814 commit 91fc30c

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

google/generativeai/types/content_types.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,10 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
670670

671671
return fd.to_proto()
672672

673+
673674
GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, dict[str, float]]
674675

676+
675677
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
676678
if isinstance(gsr, protos.GoogleSearchRetrieval):
677679
return gsr
@@ -682,13 +684,25 @@ def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
682684
if "mode" in gsr["dynamic_retrieval_config"]:
683685
print(to_mode(gsr["dynamic_retrieval_config"]["mode"]))
684686
# Create proto object from dictionary
685-
gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]),
686-
"dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"]}}}
687+
gsr = {
688+
"google_search_retrieval": {
689+
"dynamic_retrieval_config": {
690+
"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]),
691+
"dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"],
692+
}
693+
}
694+
}
687695
print(gsr)
688696
elif "mode" in gsr.keys():
689697
# Create proto object from dictionary
690-
gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["mode"]),
691-
"dynamic_threshold": gsr["dynamic_threshold"]}}}
698+
gsr = {
699+
"google_search_retrieval": {
700+
"dynamic_retrieval_config": {
701+
"mode": to_mode(gsr["mode"]),
702+
"dynamic_threshold": gsr["dynamic_threshold"],
703+
}
704+
}
705+
}
692706
return gsr
693707
else:
694708
raise TypeError(
@@ -724,10 +738,14 @@ def __init__(
724738
# Consistent fields
725739
self._function_declarations = []
726740
self._index = {}
727-
741+
728742
if google_search_retrieval:
729743
if isinstance(google_search_retrieval, str):
730-
google_search_retrieval = {"google_search_retrieval" : {"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}}}
744+
google_search_retrieval = {
745+
"google_search_retrieval": {
746+
"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}
747+
}
748+
}
731749
else:
732750
_make_google_search_retrieval(google_search_retrieval)
733751

@@ -792,7 +810,7 @@ def _make_tool(tool: ToolType) -> Tool:
792810
google_search_retrieval = tool.google_search_retrieval
793811
else:
794812
google_search_retrieval = None
795-
813+
796814
return Tool(
797815
function_declarations=tool.function_declarations,
798816
google_search_retrieval=google_search_retrieval,

tests/test_content.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,16 @@ def test_code_execution(self, tools):
435435

436436
@parameterized.named_parameters(
437437
["string", "google_search_retrieval"],
438+
["empty_dictionary", {"google_search_retrieval": {}}],
439+
["empty_dictionary_with_dynamic_retrieval_config", {"dynamic_retrieval_config": {}}],
440+
[
441+
"dictionary_with_mode_integer",
442+
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
443+
],
444+
[
445+
"dictionary_with_mode_string",
446+
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}},
447+
],
438448
[
439449
"dictionary_with_dynamic_retrieval_config",
440450
{
@@ -493,7 +503,7 @@ def test_search_grounding(self, tools):
493503
t = content_types._make_tools(tools)
494504
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
495505
else:
496-
t = content_types._make_tool(tools) # Pass code execution into tools
506+
t = content_types._make_tool(tools) # Pass google_search_retrieval into tools
497507
self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval)
498508

499509
def test_two_fun_is_one_tool(self):

0 commit comments

Comments
 (0)