Skip to content

Commit d7bc0a5

Browse files
feat: introducing watsonx sql database wrapper and toolkit (#97)
* initial implementation of sql database class * add toolkit * improve description and add extras * change import path * add basic unittests * add extended unittests * improve docstring * add unittest for tool * minor changes in pyproject.toml * raise error if no tables found * revert changes in pyproject.toml * fix format * improve docstrings * include pyarrow installation in gh actions * fix unittest * fix poetry lock * fix gh action config * fix to install dependencies for linting * fix poetry lock * improve database and use recommended way of chaining * poetry lock after update * fix linting for tests * remove unused imports * fix sdk version in pyproject toml * after poetry update * change import path for wx toolkit * minor fixes * Update libs/ibm/langchain_ibm/utilities/sql_database.py Co-authored-by: Wojciech-Rebisz <[email protected]> * Update libs/ibm/langchain_ibm/agent_toolkits/utility/utils.py Co-authored-by: Wojciech-Rebisz <[email protected]> * Update libs/ibm/langchain_ibm/utilities/sql_database.py Co-authored-by: Wojciech-Rebisz <[email protected]> * fix poetry lock * import Optional in utils * refactoring --------- Co-authored-by: Wojciech-Rebisz <[email protected]>
1 parent 776e0cc commit d7bc0a5

File tree

17 files changed

+1476
-89
lines changed

17 files changed

+1476
-89
lines changed

libs/ibm/langchain_ibm/__init__.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,37 @@
1+
import importlib
2+
import warnings
3+
from typing import Any
4+
15
from langchain_ibm.chat_models import ChatWatsonx
26
from langchain_ibm.embeddings import WatsonxEmbeddings
37
from langchain_ibm.llms import WatsonxLLM
48
from langchain_ibm.rerank import WatsonxRerank
5-
from langchain_ibm.toolkit import WatsonxTool, WatsonxToolkit
69

710
__all__ = [
811
"WatsonxLLM",
912
"WatsonxEmbeddings",
1013
"ChatWatsonx",
1114
"WatsonxRerank",
12-
"WatsonxToolkit",
13-
"WatsonxTool",
1415
]
16+
17+
18+
_module_lookup = {
19+
"WatsonxTool": "langchain_ibm.agent_toolkits.utility.toolkit",
20+
"WatsonxToolkit": "langchain_ibm.agent_toolkits.utility.toolkit",
21+
}
22+
23+
24+
def __getattr__(name: str) -> Any:
25+
"""Look up attributes dynamically."""
26+
if name in _module_lookup:
27+
warnings.warn(
28+
(
29+
f"Import path `from langchain_ibm import {name}` is deprecated "
30+
"and may be removed in future. "
31+
f"Use `from langchain_ibm.agent_toolkits.utility import {name}` instead." # noqa: E501
32+
),
33+
category=DeprecationWarning,
34+
)
35+
module = importlib.import_module(_module_lookup[name])
36+
return getattr(module, name)
37+
raise AttributeError(f"module {__name__} has no attribute {name}")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .toolkit import WatsonxSQLDatabaseToolkit
2+
3+
__all__ = ["WatsonxSQLDatabaseToolkit"]
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""Tools for interacting with a watsonx SQL databases via pyarrow.flight.FlightClient.
2+
3+
Based on the langchain_community.tools.sql_database.tool module."""
4+
5+
from typing import Any, Dict, Optional, Type, cast
6+
7+
from langchain_core.callbacks import (
8+
AsyncCallbackManagerForToolRun,
9+
CallbackManagerForToolRun,
10+
)
11+
from langchain_core.language_models import BaseLanguageModel
12+
from langchain_core.prompts import PromptTemplate
13+
from langchain_core.tools import BaseTool
14+
from pydantic import BaseModel, ConfigDict, Field, model_validator
15+
16+
from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase
17+
18+
QUERY_CHECKER = """
19+
{query}
20+
Double check the query above for common mistakes, including:
21+
- Using NOT IN with NULL values
22+
- Using UNION when UNION ALL should have been used
23+
- Using BETWEEN for exclusive ranges
24+
- Data type mismatch in predicates
25+
- Properly quoting identifiers
26+
- Using the correct number of arguments for functions
27+
- Casting to the correct data type
28+
- Using the proper columns for joins
29+
- Make sure that schema name `{schema}` is added to the table name, e.g. {schema}.table1
30+
31+
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
32+
33+
Output the final SQL query only.
34+
35+
SQL Query: """ # noqa: E501
36+
37+
38+
class BaseSQLDatabaseTool(BaseModel):
39+
"""Base tool for interacting with a SQL database."""
40+
41+
db: WatsonxSQLDatabase = Field(exclude=True)
42+
43+
model_config = ConfigDict(
44+
arbitrary_types_allowed=True,
45+
)
46+
47+
48+
class _QuerySQLDatabaseToolInput(BaseModel):
49+
query: str = Field(..., description="A detailed and correct SQL query.")
50+
51+
52+
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
53+
"""Tool for querying a SQL database."""
54+
55+
name: str = "sql_db_query"
56+
description: str = """
57+
Execute a SQL query against the database and get back the result.
58+
If the query is not correct, an error message will be returned.
59+
If an error is returned, rewrite the query, check the query correctness,
60+
and try again.
61+
"""
62+
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput
63+
64+
def _run(
65+
self,
66+
query: str,
67+
run_manager: Optional[CallbackManagerForToolRun] = None,
68+
) -> str:
69+
"""Execute the query, return the results or an error message."""
70+
return self.db.run_no_throw(query)
71+
72+
73+
class _InfoSQLDatabaseToolInput(BaseModel):
74+
table_names: str = Field(
75+
...,
76+
description=(
77+
"A comma-separated list of the table names "
78+
"for which to return the schema. "
79+
"Example input: 'table1, table2, table3'"
80+
),
81+
)
82+
83+
84+
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
85+
"""Tool for getting metadata about a SQL database."""
86+
87+
name: str = "sql_db_schema"
88+
description: str = "Get the schema and sample rows for the specified SQL tables."
89+
args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput
90+
91+
def _run(
92+
self,
93+
table_names: str,
94+
run_manager: Optional[CallbackManagerForToolRun] = None,
95+
) -> str:
96+
"""Get the schema for tables in a comma-separated list."""
97+
return self.db.get_table_info_no_throw(
98+
[t.strip() for t in table_names.split(",")]
99+
)
100+
101+
102+
class _ListSQLDatabaseToolInput(BaseModel):
103+
tool_input: str = Field("", description="An empty string")
104+
105+
106+
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
107+
"""Tool for getting tables names."""
108+
109+
name: str = "sql_db_list_tables"
110+
description: str = (
111+
"Input is an empty string, output is a comma-separated list "
112+
"of tables in the database."
113+
)
114+
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput
115+
116+
def _run(
117+
self,
118+
tool_input: str = "",
119+
run_manager: Optional[CallbackManagerForToolRun] = None,
120+
) -> str:
121+
"""Get a comma-separated list of table names."""
122+
return ", ".join(self.db.get_usable_table_names())
123+
124+
125+
class _QuerySQLCheckerToolInput(BaseModel):
126+
query: str = Field(..., description="A detailed and SQL query to be checked.")
127+
128+
129+
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
130+
"""Use an LLM to check if a query is correct."""
131+
132+
template: str = QUERY_CHECKER
133+
llm: BaseLanguageModel
134+
llm_chain: Any = Field(init=False)
135+
name: str = "sql_db_query_checker"
136+
description: str = """
137+
Use this tool to double check if your query is correct before executing it.
138+
Always use this tool before executing a query with sql_db_query!
139+
"""
140+
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
141+
142+
@model_validator(mode="before")
143+
@classmethod
144+
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Any:
145+
if "llm_chain" not in values:
146+
prompt = PromptTemplate(
147+
template=QUERY_CHECKER, input_variables=["query", "schema"]
148+
)
149+
llm = cast(BaseLanguageModel, values.get("llm"))
150+
151+
values["llm_chain"] = prompt | llm
152+
153+
if values["llm_chain"].first.input_variables != ["query", "schema"]:
154+
raise ValueError(
155+
"LLM chain for QueryCheckerTool must have input variables ['query', 'schema']" # noqa: E501
156+
)
157+
158+
return values
159+
160+
def _run(
161+
self,
162+
query: str,
163+
run_manager: Optional[CallbackManagerForToolRun] = None,
164+
) -> str:
165+
"""Use the LLM to check the query."""
166+
return self.llm_chain.invoke(
167+
{"query": query, "schema": self.db.schema},
168+
callbacks=run_manager.get_child() if run_manager else None,
169+
).content
170+
171+
async def _arun(
172+
self,
173+
query: str,
174+
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
175+
) -> str:
176+
return await self.llm_chain.apredict(
177+
query=query,
178+
schema=self.db.schema,
179+
callbacks=run_manager.get_child() if run_manager else None,
180+
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""IBM watsonx.ai SQL Toolkit wrapper."""
2+
3+
from typing import List
4+
5+
from langchain_core.language_models import BaseLanguageModel
6+
from langchain_core.tools import BaseTool
7+
from langchain_core.tools.base import BaseToolkit
8+
from pydantic import ConfigDict, Field
9+
10+
from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase
11+
12+
from .tool import (
13+
InfoSQLDatabaseTool,
14+
ListSQLDatabaseTool,
15+
QuerySQLCheckerTool,
16+
QuerySQLDatabaseTool,
17+
)
18+
19+
20+
class WatsonxSQLDatabaseToolkit(BaseToolkit):
21+
"""Toolkit for interacting with IBM watsonx.ai databases."""
22+
23+
db: WatsonxSQLDatabase = Field(exclude=True)
24+
"""Instance of the watsonx SQL database."""
25+
26+
llm: BaseLanguageModel = Field(exclude=True)
27+
"""Instance of the LLM."""
28+
29+
model_config = ConfigDict(
30+
arbitrary_types_allowed=True,
31+
)
32+
33+
def get_tools(self) -> List[BaseTool]:
34+
"""Get the tools in the toolkit."""
35+
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
36+
info_sql_database_tool_description = (
37+
"Input to this tool is a comma-separated list of tables, output is the "
38+
"SQL statement with table metadata. "
39+
"Be sure that the tables actually exist by calling "
40+
f"{list_sql_database_tool.name} first! "
41+
"Example Input: table1, table2, table3"
42+
)
43+
info_sql_database_tool = InfoSQLDatabaseTool(
44+
db=self.db, description=info_sql_database_tool_description
45+
)
46+
query_sql_database_tool_description = (
47+
"Input to this tool is a detailed and correct SQL query, output is a "
48+
"result from the database. If the query is not correct, an error message "
49+
"will be returned. If an error is returned, rewrite the query, check the "
50+
"query, and try again. If you encounter an issue with Unknown column "
51+
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
52+
"to query the correct table fields."
53+
)
54+
query_sql_database_tool = QuerySQLDatabaseTool(
55+
db=self.db, description=query_sql_database_tool_description
56+
)
57+
query_sql_checker_tool_description = (
58+
"Use this tool to double check if your query is correct before executing "
59+
"it. Always use this tool before executing a query with "
60+
f"{query_sql_database_tool.name}!"
61+
)
62+
query_sql_checker_tool = QuerySQLCheckerTool(
63+
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
64+
)
65+
return [
66+
query_sql_database_tool,
67+
info_sql_database_tool,
68+
list_sql_database_tool,
69+
query_sql_checker_tool,
70+
]
71+
72+
def get_context(self) -> dict:
73+
"""Return db context that you may want in agent prompt."""
74+
return self.db.get_context()
75+
76+
77+
WatsonxSQLDatabaseToolkit.model_rebuild()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .toolkit import WatsonxTool, WatsonxToolkit
2+
3+
__all__ = ["WatsonxToolkit", "WatsonxTool"]

libs/ibm/langchain_ibm/toolkit.py renamed to libs/ibm/langchain_ibm/agent_toolkits/utility/toolkit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
)
3030
from typing_extensions import Self
3131

32-
from langchain_ibm.utils import check_for_attribute, convert_to_watsonx_tool
32+
from langchain_ibm.utils import check_for_attribute
33+
34+
from .utils import convert_to_watsonx_tool
3335

3436

3537
class WatsonxTool(BaseTool):
@@ -132,7 +134,7 @@ class WatsonxToolkit(BaseToolkit):
132134
Example:
133135
.. code-block:: python
134136
135-
from langchain_ibm import WatsonxToolkit
137+
from langchain_ibm.agents_toolkits.utility import WatsonxToolkit
136138
137139
watsonx_toolkit = WatsonxToolkit(
138140
url="https://us-south.ml.cloud.ibm.com",

0 commit comments

Comments
 (0)