Skip to content

Commit c5cebf2

Browse files
Update test cases and _make_search_grounding
1 parent cc45552 commit c5cebf2

File tree

2 files changed

+22
-46
lines changed

2 files changed

+22
-46
lines changed

google/generativeai/types/content_types.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -677,39 +677,22 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
677677
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
678678
if isinstance(gsr, protos.GoogleSearchRetrieval):
679679
return gsr
680-
elif isinstance(gsr, Iterable) and not isinstance(gsr, Mapping):
681-
# Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval
682-
return gsr
683680
elif isinstance(gsr, Mapping):
684-
if "mode" in gsr["dynamic_retrieval_config"]:
685-
print(to_mode(gsr["dynamic_retrieval_config"]["mode"]))
686-
# Create proto object from dictionary
687-
gsr = {
688-
"google_search_retrieval": {
689-
"dynamic_retrieval_config": {
690-
"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]),
691-
"dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"],
692-
}
693-
}
694-
}
695-
print(gsr)
696-
elif "mode" in gsr.keys():
697-
# Create proto object from dictionary
698-
gsr = {
699-
"google_search_retrieval": {
700-
"dynamic_retrieval_config": {
701-
"mode": to_mode(gsr["mode"]),
702-
"dynamic_threshold": gsr["dynamic_threshold"],
703-
}
704-
}
705-
}
706-
return gsr
681+
drc = gsr.get("dynamic_retrieval_config", None)
682+
if drc is not None:
683+
mode = drc.get("mode", None)
684+
if mode is not None:
685+
mode = to_mode(mode)
686+
gsr = gsr.copy()
687+
gsr["dynamic_retrieval_config"]["mode"] = mode
688+
return protos.GoogleSearchRetrieval(gsr)
707689
else:
708690
raise TypeError(
709691
"Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
710692
f"However, received an object of type: {type(gsr)}.\n"
711693
f"Object Value: {gsr}"
712694
)
695+
713696

714697

715698
class Tool:
@@ -741,13 +724,13 @@ def __init__(
741724

742725
if google_search_retrieval:
743726
if isinstance(google_search_retrieval, str):
744-
google_search_retrieval = {
727+
self._google_search_retrieval = {
745728
"google_search_retrieval": {
746729
"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}
747730
}
748731
}
749732
else:
750-
_make_google_search_retrieval(google_search_retrieval)
733+
self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
751734

752735
self._proto = protos.Tool(
753736
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
@@ -763,7 +746,7 @@ def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDec
763746

764747
@property
765748
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
766-
return self._proto.google_search_retrieval
749+
return self._google_search_retrieval
767750

768751
@property
769752
def code_execution(self) -> protos.CodeExecution:

tests/test_content.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -426,17 +426,16 @@ def no_args():
426426
["empty_dictionary_list", [{"code_execution": {}}]],
427427
)
428428
def test_code_execution(self, tools):
429-
if isinstance(tools, Iterable):
430-
t = content_types._make_tools(tools)
431-
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
432-
else:
433-
t = content_types._make_tool(tools) # Pass code execution into tools
434-
self.assertIsInstance(t.code_execution, protos.CodeExecution)
429+
t = content_types._make_tools(tools)
430+
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
435431

436432
@parameterized.named_parameters(
437433
["string", "google_search_retrieval"],
438434
["empty_dictionary", {"google_search_retrieval": {}}],
439-
["empty_dictionary_with_dynamic_retrieval_config", {"google_search_retrieval": {"dynamic_retrieval_config": {}}}],
435+
[
436+
"empty_dictionary_with_dynamic_retrieval_config",
437+
{"google_search_retrieval": {"dynamic_retrieval_config": {}}},
438+
],
440439
[
441440
"dictionary_with_mode_integer",
442441
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
@@ -453,10 +452,6 @@ def test_code_execution(self, tools):
453452
}
454453
},
455454
],
456-
[
457-
"dictionary_without_dynamic_retrieval_config",
458-
{"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}},
459-
],
460455
[
461456
"proto_object",
462457
protos.GoogleSearchRetrieval(
@@ -499,12 +494,10 @@ def test_code_execution(self, tools):
499494
],
500495
)
501496
def test_search_grounding(self, tools):
502-
if isinstance(tools, Iterable):
503-
t = content_types._make_tools(tools)
504-
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
505-
else:
506-
t = content_types._make_tool(tools) # Pass google_search_retrieval into tools
507-
self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval)
497+
if self._testMethodName == "test_search_grounding_empty_dictionary":
498+
pass
499+
t = content_types._make_tools(tools)
500+
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
508501

509502
def test_two_fun_is_one_tool(self):
510503
def a():

0 commit comments

Comments
 (0)