generated from langchain-ai/integration-repo-template
feat: introducing watsonx sql database wrapper and toolkit #97
Open
Mateusz-Switala wants to merge 26 commits into main from feat-wx-sql-database
+1,352
−2
Changes from 17 commits
Commits
625b708 initial implementation of sql database class (Mateusz-Switala)
78a5aa5 add toolkit (Mateusz-Switala)
5603c22 improve description and add extras (Mateusz-Switala)
7a9629f change import path (Mateusz-Switala)
38247d7 add basic unittests (Mateusz-Switala)
2603229 add extended unittests (Mateusz-Switala)
09d206c improve docstring (Mateusz-Switala)
81e5740 add unittest for tool (Mateusz-Switala)
ae8f941 minor changes in pyproject.toml (Mateusz-Switala)
7b81955 raise error if no tables found (Mateusz-Switala)
8c5e041 revert changes in pyproject.toml (Mateusz-Switala)
555238d fix format (Mateusz-Switala)
49c4d5c improve docstrings (Mateusz-Switala)
946b7f7 include pyarrow installation in gh actions (Mateusz-Switala)
9911b2a fix unittest (Mateusz-Switala)
455088c fix poetry lock (Mateusz-Switala)
d43c35d fix gh action config (Mateusz-Switala)
006da26 fix to install dependencies for linting (Mateusz-Switala)
2b977a7 fix poetry lock (Mateusz-Switala)
1fe3d96 improve database and use recommended way of chaining (Mateusz-Switala)
669fbbf poetry lock after update (Mateusz-Switala)
6387549 fix linting for tests (Mateusz-Switala)
228a650 remove unused imports (Mateusz-Switala)
1a962a2 fix sdk version in pyproject toml (Mateusz-Switala)
b844978 Merge branch 'main' into feat-wx-sql-database (Mateusz-Switala)
40ec203 after poetry update (Mateusz-Switala)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .sql_toolkit import WatsonxSQLDatabaseToolkit | ||
|
||
__all__ = ["WatsonxSQLDatabaseToolkit"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""IBM watsonx.ai Toolkit wrapper.""" | ||
|
||
from typing import List | ||
|
||
from langchain_core.caches import BaseCache as BaseCache | ||
from langchain_core.callbacks import Callbacks as Callbacks | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.tools import BaseTool | ||
from langchain_core.tools.base import BaseToolkit | ||
from pydantic import ConfigDict, Field | ||
|
||
from ..utilities.sql_database import WatsonxSQLDatabase | ||
from .tool import ( | ||
InfoSQLDatabaseTool, | ||
ListSQLDatabaseTool, | ||
QuerySQLCheckerTool, | ||
QuerySQLDatabaseTool, | ||
) | ||
|
||
|
||
class WatsonxSQLDatabaseToolkit(BaseToolkit): | ||
"""Toolkit for interacting with IBM watsonx.ai databases.""" | ||
|
||
db: WatsonxSQLDatabase = Field(exclude=True) | ||
"""Instance of the watsonx SQL database.""" | ||
|
||
llm: BaseLanguageModel = Field(exclude=True) | ||
"""Instance of the LLM.""" | ||
|
||
model_config = ConfigDict( | ||
arbitrary_types_allowed=True, | ||
) | ||
|
||
def get_tools(self) -> List[BaseTool]: | ||
"""Get the tools in the toolkit.""" | ||
list_sql_database_tool = ListSQLDatabaseTool(db=self.db) | ||
info_sql_database_tool_description = ( | ||
"Input to this tool is a comma-separated list of tables, output is the " | ||
"SQL statement with table metadata. " | ||
"Be sure that the tables actually exist by calling " | ||
f"{list_sql_database_tool.name} first! " | ||
"Example Input: table1, table2, table3" | ||
) | ||
info_sql_database_tool = InfoSQLDatabaseTool( | ||
db=self.db, description=info_sql_database_tool_description | ||
) | ||
query_sql_database_tool_description = ( | ||
"Input to this tool is a detailed and correct SQL query, output is a " | ||
"result from the database. If the query is not correct, an error message " | ||
"will be returned. If an error is returned, rewrite the query, check the " | ||
"query, and try again. If you encounter an issue with Unknown column " | ||
f"'xxxx' in 'field list', use {info_sql_database_tool.name} " | ||
"to query the correct table fields." | ||
) | ||
query_sql_database_tool = QuerySQLDatabaseTool( | ||
db=self.db, description=query_sql_database_tool_description | ||
) | ||
query_sql_checker_tool_description = ( | ||
"Use this tool to double check if your query is correct before executing " | ||
"it. Always use this tool before executing a query with " | ||
f"{query_sql_database_tool.name}!" | ||
) | ||
query_sql_checker_tool = QuerySQLCheckerTool( | ||
db=self.db, llm=self.llm, description=query_sql_checker_tool_description | ||
) | ||
return [ | ||
query_sql_database_tool, | ||
info_sql_database_tool, | ||
list_sql_database_tool, | ||
query_sql_checker_tool, | ||
] | ||
|
||
def get_context(self) -> dict: | ||
"""Return db context that you may want in agent prompt.""" | ||
return self.db.get_context() | ||
|
||
|
||
WatsonxSQLDatabaseToolkit.model_rebuild() |
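
The `get_tools` method in the file above builds each tool description from the `name` of a tool constructed just before it, so the cross-references in the prompts stay correct if a tool is renamed. The following is a minimal, dependency-free sketch of that pattern only. `StubTool` and `build_tools` are hypothetical stand-ins that are not part of this PR, the descriptions are shortened, and the real tools additionally require a `WatsonxSQLDatabase` instance and an LLM.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StubTool:
    """Hypothetical stand-in for a tool: only a name and a description."""

    name: str
    description: str


def build_tools() -> List[StubTool]:
    """Build the four tools, deriving each description from earlier tool names."""
    list_tool = StubTool(
        name="sql_db_list_tables",
        description="Input is an empty string, output is a comma-separated "
        "list of tables in the database.",
    )
    info_tool = StubTool(
        name="sql_db_schema",
        description=(
            "Input to this tool is a comma-separated list of tables. "
            f"Be sure the tables exist by calling {list_tool.name} first!"
        ),
    )
    query_tool = StubTool(
        name="sql_db_query",
        description=(
            "Input to this tool is a detailed and correct SQL query. "
            f"On an unknown-column error, use {info_tool.name} "
            "to look up the correct table fields."
        ),
    )
    checker_tool = StubTool(
        name="sql_db_query_checker",
        description=(
            "Always use this tool before executing a query with "
            f"{query_tool.name}!"
        ),
    )
    # Same return order as get_tools() in the diff above.
    return [query_tool, info_tool, list_tool, checker_tool]


tools = build_tools()
for tool in tools:
    print(f"{tool.name}: {tool.description}")
```

Because every reference is interpolated from a `name` attribute rather than typed out, renaming one tool would update every description that mentions it.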
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
"""Tools for interacting with watsonx SQL databases via pyarrow.flight.FlightClient. | ||
|
||
Based on the langchain_community.tools.sql_database.tool module.""" | ||
|
||
from typing import Any, Dict, Optional, Type | ||
|
||
from langchain_core.callbacks import ( | ||
AsyncCallbackManagerForToolRun, | ||
CallbackManagerForToolRun, | ||
) | ||
from langchain_core.language_models import BaseLanguageModel | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_core.tools import BaseTool | ||
from pydantic import BaseModel, ConfigDict, Field, model_validator | ||
|
||
from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase | ||
|
||
QUERY_CHECKER = """ | ||
{query} | ||
Double check the query above for common mistakes, including: | ||
- Using NOT IN with NULL values | ||
- Using UNION when UNION ALL should have been used | ||
- Using BETWEEN for exclusive ranges | ||
- Data type mismatch in predicates | ||
- Properly quoting identifiers | ||
- Using the correct number of arguments for functions | ||
- Casting to the correct data type | ||
- Using the proper columns for joins | ||
- Make sure that schema name `{schema}` is added to the table name, e.g. {schema}.table1 | ||
|
||
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query. | ||
|
||
Output the final SQL query only. | ||
|
||
SQL Query: """ # noqa: E501 | ||
|
||
|
||
class BaseSQLDatabaseTool(BaseModel): | ||
"""Base tool for interacting with a SQL database.""" | ||
|
||
db: WatsonxSQLDatabase = Field(exclude=True) | ||
|
||
model_config = ConfigDict( | ||
arbitrary_types_allowed=True, | ||
) | ||
|
||
|
||
class _QuerySQLDatabaseToolInput(BaseModel): | ||
query: str = Field(..., description="A detailed and correct SQL query.") | ||
|
||
|
||
class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): | ||
"""Tool for querying a SQL database.""" | ||
|
||
name: str = "sql_db_query" | ||
description: str = """ | ||
Execute a SQL query against the database and get back the result. | ||
If the query is not correct, an error message will be returned. | ||
If an error is returned, rewrite the query, check the query correctness, | ||
and try again. | ||
""" | ||
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput | ||
|
||
def _run( | ||
self, | ||
query: str, | ||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||
) -> str: | ||
"""Execute the query, return the results or an error message.""" | ||
return self.db.run_no_throw(query) | ||
|
||
|
||
class _InfoSQLDatabaseToolInput(BaseModel): | ||
table_names: str = Field( | ||
..., | ||
description=( | ||
"A comma-separated list of the table names " | ||
"for which to return the schema. " | ||
"Example input: 'table1, table2, table3'" | ||
), | ||
) | ||
|
||
|
||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): | ||
"""Tool for getting metadata about a SQL database.""" | ||
|
||
name: str = "sql_db_schema" | ||
description: str = "Get the schema and sample rows for the specified SQL tables." | ||
args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput | ||
|
||
def _run( | ||
self, | ||
table_names: str, | ||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||
) -> str: | ||
"""Get the schema for tables in a comma-separated list.""" | ||
return self.db.get_table_info_no_throw( | ||
[t.strip() for t in table_names.split(",")] | ||
) | ||
|
||
|
||
class _ListSQLDatabaseToolInput(BaseModel): | ||
tool_input: str = Field("", description="An empty string") | ||
|
||
|
||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): | ||
"""Tool for getting table names.""" | ||
|
||
name: str = "sql_db_list_tables" | ||
description: str = ( | ||
"Input is an empty string, output is a comma-separated list " | ||
"of tables in the database." | ||
) | ||
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput | ||
|
||
def _run( | ||
self, | ||
tool_input: str = "", | ||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||
) -> str: | ||
"""Get a comma-separated list of table names.""" | ||
return ", ".join(self.db.get_usable_table_names()) | ||
|
||
|
||
class _QuerySQLCheckerToolInput(BaseModel): | ||
query: str = Field(..., description="A detailed SQL query to be checked.") | ||
|
||
|
||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): | ||
"""Use an LLM to check if a query is correct.""" | ||
|
||
template: str = QUERY_CHECKER | ||
llm: BaseLanguageModel | ||
llm_chain: Any = Field(init=False) | ||
name: str = "sql_db_query_checker" | ||
description: str = """ | ||
Use this tool to double check if your query is correct before executing it. | ||
Always use this tool before executing a query with sql_db_query! | ||
""" | ||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput | ||
|
||
@model_validator(mode="before") | ||
@classmethod | ||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Any: | ||
if "llm_chain" not in values: | ||
from langchain.chains.llm import LLMChain | ||
|
||
values["llm_chain"] = LLMChain( | ||
llm=values.get("llm"), # type: ignore[arg-type] | ||
prompt=PromptTemplate( | ||
template=QUERY_CHECKER, input_variables=["query", "schema"] | ||
), | ||
) | ||
|
||
if values["llm_chain"].prompt.input_variables != ["query", "schema"]: | ||
raise ValueError( | ||
"LLM chain for QueryCheckerTool must have input variables ['query', 'schema']" # noqa: E501 | ||
) | ||
|
||
return values | ||
|
||
def _run( | ||
self, | ||
query: str, | ||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||
) -> str: | ||
"""Use the LLM to check the query.""" | ||
return self.llm_chain.predict( | ||
query=query, | ||
schema=self.db.schema, | ||
callbacks=run_manager.get_child() if run_manager else None, | ||
) | ||
|
||
async def _arun( | ||
self, | ||
query: str, | ||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, | ||
) -> str: | ||
return await self.llm_chain.apredict( | ||
query=query, | ||
schema=self.db.schema, | ||
callbacks=run_manager.get_child() if run_manager else None, | ||
) |
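
`QuerySQLCheckerTool` fills two placeholders, `{query}` and `{schema}`, before the prompt is sent to the LLM. The sketch below shows only that substitution step, using plain `str.format` on an abbreviated copy of the template. `CHECKER_TEMPLATE`, `render_checker_prompt`, and the `SALES` schema name are illustrative assumptions; the tool in the diff performs this step through `PromptTemplate` and an LLM chain rather than by calling `str.format` directly.

```python
# Abbreviated copy of the QUERY_CHECKER template (most checks omitted).
CHECKER_TEMPLATE = (
    "{query}\n"
    "Double check the query above for common mistakes, including:\n"
    "- Make sure that schema name `{schema}` is added to the table name, "
    "e.g. {schema}.table1\n"
    "\n"
    "Output the final SQL query only.\n"
    "\n"
    "SQL Query: "
)


def render_checker_prompt(query: str, schema: str) -> str:
    """Substitute the two expected input variables into the template."""
    return CHECKER_TEMPLATE.format(query=query, schema=schema)


# "SALES" is a made-up schema name used purely for illustration.
prompt = render_checker_prompt("SELECT * FROM orders", "SALES")
print(prompt)
```

Passing the schema into the prompt is what lets the checker remind the model to qualify table names, which is the one check this template adds on top of the generic SQL mistakes it lists.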
Empty file.
Why are we importing BaseCache as BaseCache, etc.?
Not needed, removed.
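
For context on the exchange above: `from module import Name as Name` binds the same name as the plain import, and the repeated alias is normally kept only as an explicit re-export marker for type checkers (the convention described in PEP 484 for stub files). A small sketch for spotting such aliases with the standard `ast` module follows; `find_redundant_aliases` is a hypothetical helper written for illustration and is not part of this PR.

```python
import ast
from typing import List


def find_redundant_aliases(source: str) -> List[str]:
    """Return names imported as `from m import X as X` (alias equals name)."""
    redundant = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.asname is not None and alias.asname == alias.name:
                    redundant.append(alias.name)
    return redundant


# The source is only parsed, never executed, so langchain_core
# does not need to be installed to run this sketch.
SAMPLE = """\
from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.tools import BaseTool
"""

found = find_redundant_aliases(SAMPLE)
print(found)
```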