Skip to content

Commit 88c9fc8

Browse files
authored
feat: Improve WatsonxToolkit and WatsonxTool (#65)
* Improve WatsonxToolkit, WatsonxTool, update tests * Format, cleanup * Lint * Fix docstring * Fix lint * Some fixes and improvements * Improve docs and move converter * Fix circular import error * Change default value * Set attributes as private
1 parent a1eea3b commit 88c9fc8

File tree

8 files changed

+357
-61
lines changed

8 files changed

+357
-61
lines changed

libs/ibm/langchain_ibm/toolkit.py

Lines changed: 99 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,34 @@
22

33
import urllib.parse
44
from typing import (
5+
Any,
56
Dict,
7+
List,
68
Optional,
79
Type,
810
Union,
911
)
1012

1113
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
12-
from ibm_watsonx_ai.foundation_models.utils import Tool, Toolkit # type: ignore
14+
from ibm_watsonx_ai.foundation_models.utils import ( # type: ignore
15+
Tool,
16+
Toolkit,
17+
)
1318
from langchain_core.callbacks import CallbackManagerForToolRun
1419
from langchain_core.tools.base import BaseTool, BaseToolkit
1520
from langchain_core.utils.utils import secret_from_env
16-
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
21+
from pydantic import (
22+
BaseModel,
23+
ConfigDict,
24+
Field,
25+
PrivateAttr,
26+
SecretStr,
27+
create_model,
28+
model_validator,
29+
)
1730
from typing_extensions import Self
1831

19-
from langchain_ibm.utils import check_for_attribute
20-
21-
22-
class ToolSchema(BaseModel):
23-
input: Union[str, dict]
24-
"""Input to be used when running a tool."""
25-
26-
config: Optional[dict] = None
27-
"""Configuration options that can be passed for some tools,
28-
must match the config schema for the tool."""
32+
from langchain_ibm.utils import check_for_attribute, convert_to_watsonx_tool
2933

3034

3135
class WatsonxTool(BaseTool):
@@ -47,32 +51,66 @@ class WatsonxTool(BaseTool):
4751
tool_config_schema: Optional[Dict] = None
4852
"""Schema of the config that can be provided when running the tool if applicable."""
4953

50-
args_schema: Type[BaseModel] = ToolSchema
54+
tool_config: Optional[Dict] = None
55+
"""Config properties to be used when running a tool if applicable."""
5156

52-
watsonx_tool: Tool = Field(default=None, exclude=True) #: :meta private:
57+
args_schema: Type[BaseModel] = BaseModel
58+
59+
_watsonx_tool: Optional[Tool] = PrivateAttr(default=None) #: :meta private:
5360

5461
watsonx_client: APIClient = Field(exclude=True)
5562

5663
@model_validator(mode="after")
5764
def validate_tool(self) -> Self:
58-
self.watsonx_tool = Tool(
65+
self._watsonx_tool = Tool(
5966
api_client=self.watsonx_client,
6067
name=self.name,
6168
description=self.description,
6269
agent_description=self.agent_description,
6370
input_schema=self.tool_input_schema,
6471
config_schema=self.tool_config_schema,
6572
)
73+
converted_tool = convert_to_watsonx_tool(self)
74+
json_schema = converted_tool["function"]["parameters"]
75+
self.args_schema = json_schema_to_pydantic_model(
76+
name="ToolArgsSchema", schema=json_schema
77+
)
78+
6679
return self
6780

6881
def _run(
6982
self,
70-
input: Union[str, dict],
71-
config: Optional[dict] = None,
83+
*args: Any,
7284
run_manager: Optional[CallbackManagerForToolRun] = None,
85+
**kwargs: Any,
7386
) -> dict:
7487
"""Run the tool."""
75-
return self.watsonx_tool.run(input, config)
88+
if self.tool_input_schema is None:
89+
input = kwargs.get("input") or args[0]
90+
else:
91+
input = {
92+
k: v
93+
for k, v in kwargs.items()
94+
if k in self.tool_input_schema["properties"]
95+
}
96+
97+
return self._watsonx_tool.run(input, self.tool_config) # type: ignore[union-attr]
98+
99+
def set_tool_config(self, tool_config: dict) -> None:
100+
"""Set tool config properties.
101+
102+
Example:
103+
.. code-block:: python
104+
105+
google_search = watsonx_toolkit.get_tool("GoogleSearch")
106+
print(google_search.tool_config_schema)
107+
tool_config = {
108+
"maxResults": 3
109+
}
110+
google_search.set_tool_config(tool_config)
111+
112+
"""
113+
self.tool_config = tool_config
76114

77115

78116
class WatsonxToolkit(BaseToolkit):
@@ -99,20 +137,19 @@ class WatsonxToolkit(BaseToolkit):
99137
watsonx_toolkit = WatsonxToolkit(
100138
url="https://us-south.ml.cloud.ibm.com",
101139
apikey="*****",
102-
project_id="*****",
103140
)
104141
tools = watsonx_toolkit.get_tools()
105142
106-
google_search = watsonx_toolkit.get_tool("GoogleSearch")
143+
google_search = watsonx_toolkit.get_tool(tool_name="GoogleSearch")
107144
108-
config = {
145+
tool_config = {
109146
"maxResults": 3,
110147
}
148+
google_search.set_tool_config(tool_config)
111149
input = {
112150
"input": "Search IBM",
113-
"config": config,
114151
}
115-
search_result = google_search.invoke(input=input)
152+
search_result = google_search.invoke(input)
116153
117154
"""
118155

@@ -145,7 +182,10 @@ class WatsonxToolkit(BaseToolkit):
145182
* True - default path to truststore will be taken
146183
* False - no verification will be made"""
147184

148-
watsonx_toolkit: Toolkit = Field(default=None, exclude=True) #: :meta private:
185+
_tools: Optional[List[WatsonxTool]] = None
186+
"""Tools in the toolkit."""
187+
188+
_watsonx_toolkit: Optional[Toolkit] = PrivateAttr(default=None) #: :meta private:
149189

150190
watsonx_client: Optional[APIClient] = Field(default=None, exclude=True)
151191

@@ -155,7 +195,7 @@ class WatsonxToolkit(BaseToolkit):
155195
def validate_environment(self) -> Self:
156196
"""Validate that credentials and python package exists in environment."""
157197
if isinstance(self.watsonx_client, APIClient):
158-
self.watsonx_toolkit = Toolkit(self.watsonx_client)
198+
self._watsonx_toolkit = Toolkit(self.watsonx_client)
159199
else:
160200
check_for_attribute(self.url, "url", "WATSONX_URL")
161201

@@ -187,15 +227,9 @@ def validate_environment(self) -> Self:
187227
project_id=self.project_id,
188228
space_id=self.space_id,
189229
)
190-
self.watsonx_toolkit = Toolkit(self.watsonx_client)
191-
192-
return self
193-
194-
def get_tools(self) -> list[WatsonxTool]: # type: ignore
195-
"""Get the tools in the toolkit."""
196-
tools = self.watsonx_toolkit.get_tools()
230+
self._watsonx_toolkit = Toolkit(self.watsonx_client)
197231

198-
return [
232+
self._tools = [
199233
WatsonxTool(
200234
watsonx_client=self.watsonx_client,
201235
name=tool["name"],
@@ -204,13 +238,42 @@ def get_tools(self) -> list[WatsonxTool]: # type: ignore
204238
tool_input_schema=tool.get("input_schema"),
205239
tool_config_schema=tool.get("config_schema"),
206240
)
207-
for tool in tools
241+
for tool in self._watsonx_toolkit.get_tools()
208242
]
209243

244+
return self
245+
246+
def get_tools(self) -> list[WatsonxTool]: # type: ignore
247+
"""Get the tools in the toolkit."""
248+
return self._tools # type: ignore[return-value]
249+
210250
def get_tool(self, tool_name: str) -> WatsonxTool:
211251
"""Get the tool with a given name."""
212-
tools = self.get_tools()
213-
for tool in tools:
252+
for tool in self.get_tools():
214253
if tool.name == tool_name:
215254
return tool
216255
raise ValueError(f"A tool with the given name ({tool_name}) was not found.")
256+
257+
258+
def json_schema_to_pydantic_model(name: str, schema: Dict[str, Any]) -> Type[BaseModel]:
259+
properties = schema.get("properties", {})
260+
fields = {}
261+
262+
type_mapping = {
263+
"string": str,
264+
"integer": int,
265+
"number": float,
266+
"boolean": bool,
267+
"array": list,
268+
"object": dict,
269+
}
270+
271+
for field_name, field_schema in properties.items():
272+
field_type = field_schema.get("type", "string")
273+
is_required = field_name in schema.get("required", [])
274+
275+
py_type = type_mapping.get(field_type, Any)
276+
277+
fields[field_name] = (py_type, ... if is_required else None)
278+
279+
return create_model(name, **fields) # type: ignore[call-overload]

libs/ibm/langchain_ibm/utils.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from copy import deepcopy
2-
from typing import Any, Dict, Optional, Union
2+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
33

44
from ibm_watsonx_ai.foundation_models.schema import BaseSchema # type: ignore
55
from pydantic import SecretStr
66

7+
if TYPE_CHECKING:
8+
from langchain_ibm.toolkit import WatsonxTool
9+
710

811
def check_for_attribute(value: SecretStr | None, key: str, env_key: str) -> None:
912
if not value or not value.get_secret_value():
@@ -57,3 +60,78 @@ def check_duplicate_chat_params(params: dict, kwargs: dict) -> None:
5760
f"Duplicate parameters found in params and keyword arguments: "
5861
f"{list(duplicate_keys)}"
5962
)
63+
64+
65+
def convert_to_watsonx_tool(tool: "WatsonxTool") -> dict:
66+
"""Convert `WatsonxTool` to watsonx tool structure.
67+
68+
Args:
69+
tool: `WatsonxTool` from `WatsonxToolkit`
70+
71+
72+
Example:
73+
74+
.. code-block:: python
75+
76+
from langchain_ibm import WatsonxToolkit
77+
78+
watsonx_toolkit = WatsonxToolkit(
79+
url="https://us-south.ml.cloud.ibm.com",
80+
apikey="*****",
81+
)
82+
weather_tool = watsonx_toolkit.get_tool("Weather")
83+
convert_to_watsonx_tool(weather_tool)
84+
85+
# Return
86+
# {
87+
# "type": "function",
88+
# "function": {
89+
# "name": "Weather",
90+
# "description": "Find the weather for a city.",
91+
# "parameters": {
92+
# "type": "object",
93+
# "properties": {
94+
# "location": {
95+
# "title": "location",
96+
# "description": "Name of the location",
97+
# "type": "string",
98+
# },
99+
# "country": {
100+
# "title": "country",
101+
# "description": "Name of the state or country",
102+
# "type": "string",
103+
# },
104+
# },
105+
# "required": ["location"],
106+
# },
107+
# },
108+
# }
109+
110+
"""
111+
112+
def parse_parameters(input_schema: dict | None) -> dict:
113+
if input_schema:
114+
parameters = deepcopy(input_schema)
115+
else:
116+
parameters = {
117+
"type": "object",
118+
"properties": {
119+
"input": {
120+
"description": "Input to be used when running tool.",
121+
"type": "string",
122+
},
123+
},
124+
"required": ["input"],
125+
}
126+
127+
return parameters
128+
129+
watsonx_tool = {
130+
"type": "function",
131+
"function": {
132+
"name": tool.name,
133+
"description": tool.description,
134+
"parameters": parse_parameters(tool.tool_input_schema),
135+
},
136+
}
137+
return watsonx_tool

0 commit comments

Comments
 (0)