73
73
74
74
Mode = protos .DynamicRetrievalConfig .Mode
75
75
76
- ModeOptions = Union [str , str , Mode ]
76
+ ModeOptions = Union [int , str , Mode ]
77
77
78
78
_MODE : dict [ModeOptions , Mode ] = {
79
79
Mode .MODE_UNSPECIFIED : Mode .MODE_UNSPECIFIED ,
80
80
0 : Mode .MODE_UNSPECIFIED ,
81
81
"mode_unspecified" : Mode .MODE_UNSPECIFIED ,
82
82
"unspecified" : Mode .MODE_UNSPECIFIED ,
83
- Mode .DYNAMIC : Mode .DYNAMIC ,
84
- 1 : Mode .DYNAMIC ,
85
- "mode_dynamic" : Mode .DYNAMIC ,
86
- "dynamic" : Mode .DYNAMIC ,
83
+ Mode .MODE_DYNAMIC : Mode .MODE_DYNAMIC ,
84
+ 1 : Mode .MODE_DYNAMIC ,
85
+ "mode_dynamic" : Mode .MODE_DYNAMIC ,
86
+ "dynamic" : Mode .MODE_DYNAMIC ,
87
87
}
88
88
89
89
@@ -670,14 +670,43 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
670
670
671
671
return fd .to_proto ()
672
672
673
+ GoogleSearchRetrievalType = Union [protos .GoogleSearchRetrieval , dict [str , float ]]
674
+
675
+ def _make_google_search_retrieval (gsr : GoogleSearchRetrievalType ):
676
+ if isinstance (gsr , protos .GoogleSearchRetrieval ):
677
+ return gsr
678
+ elif isinstance (gsr , Iterable ) and not isinstance (gsr , Mapping ):
679
+ # Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval
680
+ return gsr
681
+ elif isinstance (gsr , Mapping ):
682
+ if "mode" in gsr ["dynamic_retrieval_config" ]:
683
+ print (to_mode (gsr ["dynamic_retrieval_config" ]["mode" ]))
684
+ # Create proto object from dictionary
685
+ gsr = {"google_search_retrieval" : {"dynamic_retrieval_config" : {"mode" : to_mode (gsr ["dynamic_retrieval_config" ]["mode" ]),
686
+ "dynamic_threshold" : gsr ["dynamic_retrieval_config" ]["dynamic_threshold" ]}}}
687
+ print (gsr )
688
+ elif "mode" in gsr .keys ():
689
+ # Create proto object from dictionary
690
+ gsr = {"google_search_retrieval" : {"dynamic_retrieval_config" : {"mode" : to_mode (gsr ["mode" ]),
691
+ "dynamic_threshold" : gsr ["dynamic_threshold" ]}}}
692
+ return gsr
693
+ else :
694
+ raise TypeError (
695
+ "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n "
696
+ f"However, received an object of type: { type (gsr )} .\n "
697
+ f"Object Value: { gsr } "
698
+ )
699
+
673
700
674
701
class Tool :
675
- """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
702
+ """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
703
+ protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
676
704
677
705
def __init__ (
678
706
self ,
707
+ * ,
679
708
function_declarations : Iterable [FunctionDeclarationType ] | None = None ,
680
- google_search_retrieval : protos .GoogleSearchRetrieval | None = None ,
709
+ google_search_retrieval : Union [ protos .GoogleSearchRetrieval , str ] | None = None ,
681
710
code_execution : protos .CodeExecution | None = None ,
682
711
):
683
712
# The main path doesn't use this but is seems useful.
@@ -695,17 +724,29 @@ def __init__(
695
724
# Consistent fields
696
725
self ._function_declarations = []
697
726
self ._index = {}
727
+
728
+ if google_search_retrieval :
729
+ if isinstance (google_search_retrieval , str ):
730
+ google_search_retrieval = {"google_search_retrieval" : {"dynamic_retrieval_config" : {"mode" : to_mode (google_search_retrieval )}}}
731
+ else :
732
+ _make_google_search_retrieval (google_search_retrieval )
698
733
699
734
self ._proto = protos .Tool (
700
735
function_declarations = [_encode_fd (fd ) for fd in self ._function_declarations ],
701
736
google_search_retrieval = google_search_retrieval ,
702
737
code_execution = code_execution ,
703
738
)
704
739
740
+ print (self ._proto .google_search_retrieval )
741
+
705
742
@property
706
743
def function_declarations (self ) -> list [FunctionDeclaration | protos .FunctionDeclaration ]:
707
744
return self ._function_declarations
708
745
746
+ @property
747
+ def google_search_retrieval (self ) -> protos .GoogleSearchRetrieval :
748
+ return self ._proto .google_search_retrieval
749
+
709
750
@property
710
751
def code_execution (self ) -> protos .CodeExecution :
711
752
return self ._proto .code_execution
@@ -734,7 +775,7 @@ class ToolDict(TypedDict):
734
775
735
776
736
777
ToolType = Union [
737
- Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
778
+ str , Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
738
779
]
739
780
740
781
@@ -746,9 +787,15 @@ def _make_tool(tool: ToolType) -> Tool:
746
787
code_execution = tool .code_execution
747
788
else :
748
789
code_execution = None
790
+
791
+ if "google_search_retrieval" in tool :
792
+ google_search_retrieval = tool .google_search_retrieval
793
+ else :
794
+ google_search_retrieval = None
795
+
749
796
return Tool (
750
797
function_declarations = tool .function_declarations ,
751
- google_search_retrieval = tool . google_search_retrieval ,
798
+ google_search_retrieval = google_search_retrieval ,
752
799
code_execution = code_execution ,
753
800
)
754
801
elif isinstance (tool , dict ):
@@ -765,9 +812,8 @@ def _make_tool(tool: ToolType) -> Tool:
765
812
if tool .lower () == "code_execution" :
766
813
return Tool (code_execution = protos .CodeExecution ())
767
814
# Check to see if one of the mode enums matches
768
- elif to_mode (tool ) == Mode .MODE_UNSPECIFIED or to_mode (tool ) == Mode .DYNAMIC :
769
- mode = to_mode (tool )
770
- return Tool (google_search_retrieval = protos .GoogleSearchRetrieval (mode = mode ))
815
+ elif tool .lower () == "google_search_retrieval" :
816
+ return Tool (google_search_retrieval = protos .GoogleSearchRetrieval ())
771
817
else :
772
818
raise ValueError (
773
819
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
@@ -831,7 +877,7 @@ def to_proto(self):
831
877
832
878
def _make_tools (tools : ToolsType ) -> list [Tool ]:
833
879
if isinstance (tools , str ):
834
- if tools .lower () == "code_execution" :
880
+ if tools .lower () == "code_execution" or tools . lower () == "google_search_retrieval" :
835
881
return [_make_tool (tools )]
836
882
else :
837
883
raise ValueError ("The only string that can be passed as a tool is 'code_execution'." )
0 commit comments