2
2
3
3
import urllib .parse
4
4
from typing import (
5
+ Any ,
5
6
Dict ,
7
+ List ,
6
8
Optional ,
7
9
Type ,
8
10
Union ,
9
11
)
10
12
11
13
from ibm_watsonx_ai import APIClient , Credentials # type: ignore
12
- from ibm_watsonx_ai .foundation_models .utils import Tool , Toolkit # type: ignore
14
+ from ibm_watsonx_ai .foundation_models .utils import ( # type: ignore
15
+ Tool ,
16
+ Toolkit ,
17
+ )
13
18
from langchain_core .callbacks import CallbackManagerForToolRun
14
19
from langchain_core .tools .base import BaseTool , BaseToolkit
15
20
from langchain_core .utils .utils import secret_from_env
16
- from pydantic import BaseModel , ConfigDict , Field , SecretStr , model_validator
21
+ from pydantic import (
22
+ BaseModel ,
23
+ ConfigDict ,
24
+ Field ,
25
+ PrivateAttr ,
26
+ SecretStr ,
27
+ create_model ,
28
+ model_validator ,
29
+ )
17
30
from typing_extensions import Self
18
31
19
- from langchain_ibm .utils import check_for_attribute
20
-
21
-
22
- class ToolSchema (BaseModel ):
23
- input : Union [str , dict ]
24
- """Input to be used when running a tool."""
25
-
26
- config : Optional [dict ] = None
27
- """Configuration options that can be passed for some tools,
28
- must match the config schema for the tool."""
32
+ from langchain_ibm .utils import check_for_attribute , convert_to_watsonx_tool
29
33
30
34
31
35
class WatsonxTool (BaseTool ):
@@ -47,32 +51,66 @@ class WatsonxTool(BaseTool):
47
51
tool_config_schema : Optional [Dict ] = None
48
52
"""Schema of the config that can be provided when running the tool if applicable."""
49
53
50
- args_schema : Type [BaseModel ] = ToolSchema
54
+ tool_config : Optional [Dict ] = None
55
+ """Config properties to be used when running a tool if applicable."""
51
56
52
- watsonx_tool : Tool = Field (default = None , exclude = True ) #: :meta private:
57
+ args_schema : Type [BaseModel ] = BaseModel
58
+
59
+ _watsonx_tool : Optional [Tool ] = PrivateAttr (default = None ) #: :meta private:
53
60
54
61
watsonx_client : APIClient = Field (exclude = True )
55
62
56
63
@model_validator (mode = "after" )
57
64
def validate_tool (self ) -> Self :
58
- self .watsonx_tool = Tool (
65
+ self ._watsonx_tool = Tool (
59
66
api_client = self .watsonx_client ,
60
67
name = self .name ,
61
68
description = self .description ,
62
69
agent_description = self .agent_description ,
63
70
input_schema = self .tool_input_schema ,
64
71
config_schema = self .tool_config_schema ,
65
72
)
73
+ converted_tool = convert_to_watsonx_tool (self )
74
+ json_schema = converted_tool ["function" ]["parameters" ]
75
+ self .args_schema = json_schema_to_pydantic_model (
76
+ name = "ToolArgsSchema" , schema = json_schema
77
+ )
78
+
66
79
return self
67
80
68
81
def _run (
69
82
self ,
70
- input : Union [str , dict ],
71
- config : Optional [dict ] = None ,
83
+ * args : Any ,
72
84
run_manager : Optional [CallbackManagerForToolRun ] = None ,
85
+ ** kwargs : Any ,
73
86
) -> dict :
74
87
"""Run the tool."""
75
- return self .watsonx_tool .run (input , config )
88
+ if self .tool_input_schema is None :
89
+ input = kwargs .get ("input" ) or args [0 ]
90
+ else :
91
+ input = {
92
+ k : v
93
+ for k , v in kwargs .items ()
94
+ if k in self .tool_input_schema ["properties" ]
95
+ }
96
+
97
+ return self ._watsonx_tool .run (input , self .tool_config ) # type: ignore[union-attr]
98
+
99
+ def set_tool_config (self , tool_config : dict ) -> None :
100
+ """Set tool config properties.
101
+
102
+ Example:
103
+ .. code-block:: python
104
+
105
+ google_search = watsonx_toolkit.get_tool("GoogleSearch")
106
+ print(google_search.tool_config_schema)
107
+ tool_config = {
108
+ "maxResults": 3
109
+ }
110
+ google_search.set_tool_config(tool_config)
111
+
112
+ """
113
+ self .tool_config = tool_config
76
114
77
115
78
116
class WatsonxToolkit (BaseToolkit ):
@@ -99,20 +137,19 @@ class WatsonxToolkit(BaseToolkit):
99
137
watsonx_toolkit = WatsonxToolkit(
100
138
url="https://us-south.ml.cloud.ibm.com",
101
139
apikey="*****",
102
- project_id="*****",
103
140
)
104
141
tools = watsonx_toolkit.get_tools()
105
142
106
- google_search = watsonx_toolkit.get_tool("GoogleSearch")
143
+ google_search = watsonx_toolkit.get_tool(tool_name= "GoogleSearch")
107
144
108
- config = {
145
+ tool_config = {
109
146
"maxResults": 3,
110
147
}
148
+ google_search.set_tool_config(tool_config)
111
149
input = {
112
150
"input": "Search IBM",
113
- "config": config,
114
151
}
115
- search_result = google_search.invoke(input=input )
152
+ search_result = google_search.invoke(input)
116
153
117
154
"""
118
155
@@ -145,7 +182,10 @@ class WatsonxToolkit(BaseToolkit):
145
182
* True - default path to truststore will be taken
146
183
* False - no verification will be made"""
147
184
148
- watsonx_toolkit : Toolkit = Field (default = None , exclude = True ) #: :meta private:
185
+ _tools : Optional [List [WatsonxTool ]] = None
186
+ """Tools in the toolkit."""
187
+
188
+ _watsonx_toolkit : Optional [Toolkit ] = PrivateAttr (default = None ) #: :meta private:
149
189
150
190
watsonx_client : Optional [APIClient ] = Field (default = None , exclude = True )
151
191
@@ -155,7 +195,7 @@ class WatsonxToolkit(BaseToolkit):
155
195
def validate_environment (self ) -> Self :
156
196
"""Validate that credentials and python package exists in environment."""
157
197
if isinstance (self .watsonx_client , APIClient ):
158
- self .watsonx_toolkit = Toolkit (self .watsonx_client )
198
+ self ._watsonx_toolkit = Toolkit (self .watsonx_client )
159
199
else :
160
200
check_for_attribute (self .url , "url" , "WATSONX_URL" )
161
201
@@ -187,15 +227,9 @@ def validate_environment(self) -> Self:
187
227
project_id = self .project_id ,
188
228
space_id = self .space_id ,
189
229
)
190
- self .watsonx_toolkit = Toolkit (self .watsonx_client )
191
-
192
- return self
193
-
194
- def get_tools (self ) -> list [WatsonxTool ]: # type: ignore
195
- """Get the tools in the toolkit."""
196
- tools = self .watsonx_toolkit .get_tools ()
230
+ self ._watsonx_toolkit = Toolkit (self .watsonx_client )
197
231
198
- return [
232
+ self . _tools = [
199
233
WatsonxTool (
200
234
watsonx_client = self .watsonx_client ,
201
235
name = tool ["name" ],
@@ -204,13 +238,42 @@ def get_tools(self) -> list[WatsonxTool]: # type: ignore
204
238
tool_input_schema = tool .get ("input_schema" ),
205
239
tool_config_schema = tool .get ("config_schema" ),
206
240
)
207
- for tool in tools
241
+ for tool in self . _watsonx_toolkit . get_tools ()
208
242
]
209
243
244
+ return self
245
+
246
+ def get_tools (self ) -> list [WatsonxTool ]: # type: ignore
247
+ """Get the tools in the toolkit."""
248
+ return self ._tools # type: ignore[return-value]
249
+
210
250
def get_tool (self , tool_name : str ) -> WatsonxTool :
211
251
"""Get the tool with a given name."""
212
- tools = self .get_tools ()
213
- for tool in tools :
252
+ for tool in self .get_tools ():
214
253
if tool .name == tool_name :
215
254
return tool
216
255
raise ValueError (f"A tool with the given name ({ tool_name } ) was not found." )
256
+
257
+
258
+ def json_schema_to_pydantic_model (name : str , schema : Dict [str , Any ]) -> Type [BaseModel ]:
259
+ properties = schema .get ("properties" , {})
260
+ fields = {}
261
+
262
+ type_mapping = {
263
+ "string" : str ,
264
+ "integer" : int ,
265
+ "number" : float ,
266
+ "boolean" : bool ,
267
+ "array" : list ,
268
+ "object" : dict ,
269
+ }
270
+
271
+ for field_name , field_schema in properties .items ():
272
+ field_type = field_schema .get ("type" , "string" )
273
+ is_required = field_name in schema .get ("required" , [])
274
+
275
+ py_type = type_mapping .get (field_type , Any )
276
+
277
+ fields [field_name ] = (py_type , ... if is_required else None )
278
+
279
+ return create_model (name , ** fields ) # type: ignore[call-overload]
0 commit comments