2323log = logging .getLogger (__name__ )
2424
2525
26- def apply_extra_params_to_tool_function (
26+ def get_async_tool_function_and_apply_extra_params (
2727 function : Callable , extra_params : dict
2828) -> Callable [..., Awaitable ]:
2929 sig = inspect .signature (function )
3030 extra_params = {k : v for k , v in extra_params .items () if k in sig .parameters }
3131 partial_func = partial (function , ** extra_params )
32+
3233 if inspect .iscoroutinefunction (function ):
3334 update_wrapper (partial_func , function )
3435 return partial_func
36+ else :
37+ # Make it a coroutine function
38+ async def new_function (* args , ** kwargs ):
39+ return partial_func (* args , ** kwargs )
3540
36- async def new_function (* args , ** kwargs ):
37- return partial_func (* args , ** kwargs )
38-
39- update_wrapper (new_function , function )
40- return new_function
41+ update_wrapper (new_function , function )
42+ return new_function
4143
4244
4345def get_tools (
@@ -48,22 +50,49 @@ def get_tools(
4850 for tool_id in tool_ids :
4951 tool = Tools .get_tool_by_id (tool_id )
5052 if tool is None :
51-
5253 if tool_id .startswith ("server:" ):
5354 server_idx = int (tool_id .split (":" )[1 ])
55+ tool_server_connection = (
56+ request .app .state .config .TOOL_SERVER_CONNECTIONS [server_idx ]
57+ )
5458 tool_server_data = request .app .state .TOOL_SERVERS [server_idx ]
59+ specs = tool_server_data .get ("specs" , [])
60+
61+ for spec in specs :
62+ function_name = spec ["name" ]
63+
64+ auth_type = tool_server_connection .get ("auth_type" , "bearer" )
65+ token = None
66+
67+ if auth_type == "bearer" :
68+ token = tool_server_connection .get ("key" , "" )
69+ elif auth_type == "session" :
70+ token = request .state .token .credentials
71+
72+ callable = get_async_tool_function_and_apply_extra_params (
73+ execute_tool_server ,
74+ {
75+ "token" : token ,
76+ "url" : tool_server_data ["url" ],
77+ "name" : function_name ,
78+ "server_data" : tool_server_data ,
79+ },
80+ )
5581
56- tool_dict = {
57- "spec" : spec ,
58- "callable" : callable ,
59- "tool_id" : tool_id ,
60- # Misc info
61- "metadata" : {
62- "file_handler" : hasattr (module , "file_handler" )
63- and module .file_handler ,
64- "citation" : hasattr (module , "citation" ) and module .citation ,
65- },
66- }
82+ tool_dict = {
83+ "tool_id" : tool_id ,
84+ "callable" : callable ,
85+ "spec" : spec ,
86+ }
87+
88+ # TODO: if collision, prepend toolkit name
89+ if function_name in tools_dict :
90+ log .warning (
91+ f"Tool { function_name } already exists in another tools!"
92+ )
93+ log .warning (f"Discarding { tool_id } .{ function_name } " )
94+ else :
95+ tools_dict [function_name ] = tool_dict
6796 else :
6897 continue
6998 else :
@@ -73,10 +102,11 @@ def get_tools(
73102 request .app .state .TOOLS [tool_id ] = module
74103
75104 extra_params ["__id__" ] = tool_id
105+
106+ # Set valves for the tool
76107 if hasattr (module , "valves" ) and hasattr (module , "Valves" ):
77108 valves = Tools .get_tool_valves_by_id (tool_id ) or {}
78109 module .valves = module .Valves (** valves )
79-
80110 if hasattr (module , "UserValves" ):
81111 extra_params ["__user__" ]["valves" ] = module .UserValves ( # type: ignore
82112 ** Tools .get_user_valves_by_id_and_user_id (tool_id , user .id )
@@ -89,31 +119,31 @@ def get_tools(
89119 if val ["type" ] == "str" :
90120 val ["type" ] = "string"
91121
92- # Remove internal parameters
122+ # Remove internal reserved parameters (e.g. __id__, __user__)
93123 spec ["parameters" ]["properties" ] = {
94124 key : val
95125 for key , val in spec ["parameters" ]["properties" ].items ()
96126 if not key .startswith ("__" )
97127 }
98128
99- function_name = spec ["name" ]
100-
101129 # convert to function that takes only model params and inserts custom params
102- original_func = getattr (module , function_name )
103- callable = apply_extra_params_to_tool_function (
104- original_func , extra_params
130+ function_name = spec ["name" ]
131+ tool_function = getattr (module , function_name )
132+ callable = get_async_tool_function_and_apply_extra_params (
133+ tool_function , extra_params
105134 )
106135
136+ # TODO: Support Pydantic models as parameters
107137 if callable .__doc__ and callable .__doc__ .strip () != "" :
108138 s = re .split (":(param|return)" , callable .__doc__ , 1 )
109139 spec ["description" ] = s [0 ]
110140 else :
111141 spec ["description" ] = function_name
112142
113143 tool_dict = {
114- "spec" : spec ,
115- "callable" : callable ,
116144 "tool_id" : tool_id ,
145+ "callable" : callable ,
146+ "spec" : spec ,
117147 # Misc info
118148 "metadata" : {
119149 "file_handler" : hasattr (module , "file_handler" )
@@ -127,8 +157,7 @@ def get_tools(
127157 log .warning (
128158 f"Tool { function_name } already exists in another tools!"
129159 )
130- log .warning (f"Collision between { tool } and { tool_id } ." )
131- log .warning (f"Discarding { tool } .{ function_name } " )
160+ log .warning (f"Discarding { tool_id } .{ function_name } " )
132161 else :
133162 tools_dict [function_name ] = tool_dict
134163
0 commit comments