Skip to content

Commit 84a5f29

Browse files
update content_types
1 parent fa1651d commit 84a5f29

File tree

1 file changed

+59
-13
lines changed

1 file changed

+59
-13
lines changed

google/generativeai/types/content_types.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@
7373

7474
Mode = protos.DynamicRetrievalConfig.Mode
7575

76-
ModeOptions = Union[str, str, Mode]
76+
ModeOptions = Union[int, str, Mode]
7777

7878
_MODE: dict[ModeOptions, Mode] = {
7979
Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
8080
0: Mode.MODE_UNSPECIFIED,
8181
"mode_unspecified": Mode.MODE_UNSPECIFIED,
8282
"unspecified": Mode.MODE_UNSPECIFIED,
83-
Mode.DYNAMIC: Mode.DYNAMIC,
84-
1: Mode.DYNAMIC,
85-
"mode_dynamic": Mode.DYNAMIC,
86-
"dynamic": Mode.DYNAMIC,
83+
Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
84+
1: Mode.MODE_DYNAMIC,
85+
"mode_dynamic": Mode.MODE_DYNAMIC,
86+
"dynamic": Mode.MODE_DYNAMIC,
8787
}
8888

8989

@@ -670,14 +670,43 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
670670

671671
return fd.to_proto()
672672

673+
GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, dict[str, float]]
674+
675+
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
676+
if isinstance(gsr, protos.GoogleSearchRetrieval):
677+
return gsr
678+
elif isinstance(gsr, Iterable) and not isinstance(gsr, Mapping):
679+
# Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval
680+
return gsr
681+
elif isinstance(gsr, Mapping):
682+
if "mode" in gsr["dynamic_retrieval_config"]:
683+
print(to_mode(gsr["dynamic_retrieval_config"]["mode"]))
684+
# Create proto object from dictionary
685+
gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]),
686+
"dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"]}}}
687+
print(gsr)
688+
elif "mode" in gsr.keys():
689+
# Create proto object from dictionary
690+
gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["mode"]),
691+
"dynamic_threshold": gsr["dynamic_threshold"]}}}
692+
return gsr
693+
else:
694+
raise TypeError(
695+
"Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
696+
f"However, received an object of type: {type(gsr)}.\n"
697+
f"Object Value: {gsr}"
698+
)
699+
673700

674701
class Tool:
675-
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
702+
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
703+
protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
676704

677705
def __init__(
678706
self,
707+
*,
679708
function_declarations: Iterable[FunctionDeclarationType] | None = None,
680-
google_search_retrieval: protos.GoogleSearchRetrieval | None = None,
709+
google_search_retrieval: Union[protos.GoogleSearchRetrieval, str] | None = None,
681710
code_execution: protos.CodeExecution | None = None,
682711
):
683712
# The main path doesn't use this but is seems useful.
@@ -695,17 +724,29 @@ def __init__(
695724
# Consistent fields
696725
self._function_declarations = []
697726
self._index = {}
727+
728+
if google_search_retrieval:
729+
if isinstance(google_search_retrieval, str):
730+
google_search_retrieval = {"google_search_retrieval" : {"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}}}
731+
else:
732+
_make_google_search_retrieval(google_search_retrieval)
698733

699734
self._proto = protos.Tool(
700735
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
701736
google_search_retrieval=google_search_retrieval,
702737
code_execution=code_execution,
703738
)
704739

740+
print(self._proto.google_search_retrieval)
741+
705742
@property
706743
def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
707744
return self._function_declarations
708745

746+
@property
747+
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
748+
return self._proto.google_search_retrieval
749+
709750
@property
710751
def code_execution(self) -> protos.CodeExecution:
711752
return self._proto.code_execution
@@ -734,7 +775,7 @@ class ToolDict(TypedDict):
734775

735776

736777
ToolType = Union[
737-
Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
778+
str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
738779
]
739780

740781

@@ -746,9 +787,15 @@ def _make_tool(tool: ToolType) -> Tool:
746787
code_execution = tool.code_execution
747788
else:
748789
code_execution = None
790+
791+
if "google_search_retrieval" in tool:
792+
google_search_retrieval = tool.google_search_retrieval
793+
else:
794+
google_search_retrieval = None
795+
749796
return Tool(
750797
function_declarations=tool.function_declarations,
751-
google_search_retrieval=tool.google_search_retrieval,
798+
google_search_retrieval=google_search_retrieval,
752799
code_execution=code_execution,
753800
)
754801
elif isinstance(tool, dict):
@@ -765,9 +812,8 @@ def _make_tool(tool: ToolType) -> Tool:
765812
if tool.lower() == "code_execution":
766813
return Tool(code_execution=protos.CodeExecution())
767814
# Check to see if one of the mode enums matches
768-
elif to_mode(tool) == Mode.MODE_UNSPECIFIED or to_mode(tool) == Mode.DYNAMIC:
769-
mode = to_mode(tool)
770-
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval(mode=mode))
815+
elif tool.lower() == "google_search_retrieval":
816+
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
771817
else:
772818
raise ValueError(
773819
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
@@ -831,7 +877,7 @@ def to_proto(self):
831877

832878
def _make_tools(tools: ToolsType) -> list[Tool]:
833879
if isinstance(tools, str):
834-
if tools.lower() == "code_execution":
880+
if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval":
835881
return [_make_tool(tools)]
836882
else:
837883
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")

0 commit comments

Comments
 (0)