Skip to content

feat: introducing watsonx sql database wrapper and toolkit #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libs/ibm/langchain_ibm/agent_toolkits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .sql_toolkit import WatsonxSQLDatabaseToolkit

__all__ = ["WatsonxSQLDatabaseToolkit"]
76 changes: 76 additions & 0 deletions libs/ibm/langchain_ibm/agent_toolkits/sql_toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""IBM watsonx.ai Toolkit wrapper."""

from typing import List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field

from ..utilities.sql_database import WatsonxSQLDatabase
from .tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDatabaseTool,
)


class WatsonxSQLDatabaseToolkit(BaseToolkit):
"""Toolkit for interacting with IBM watsonx.ai databases."""

db: WatsonxSQLDatabase = Field(exclude=True)
"""Instance of the watsonx SQL database."""

llm: BaseLanguageModel = Field(exclude=True)
"""Instance of the LLM."""

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"SQL statement with table metadata. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
)
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDatabaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
)
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
query_sql_checker_tool,
]

def get_context(self) -> dict:
"""Return db context that you may want in agent prompt."""
return self.db.get_context()


WatsonxSQLDatabaseToolkit.model_rebuild()
180 changes: 180 additions & 0 deletions libs/ibm/langchain_ibm/agent_toolkits/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""Tools for interacting with a watsonx SQL databases via pyarrow.flight.FlightClient.

Based on the langchain_community.tools.sql_database.tool module."""

from typing import Any, Dict, Optional, Type, cast

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator

from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase

QUERY_CHECKER = """
{query}
Double check the query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
- Make sure that schema name `{schema}` is added to the table name, e.g. {schema}.table1

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only.

SQL Query: """ # noqa: E501


class BaseSQLDatabaseTool(BaseModel):
"""Base tool for interacting with a SQL database."""

db: WatsonxSQLDatabase = Field(exclude=True)

model_config = ConfigDict(
arbitrary_types_allowed=True,
)


class _QuerySQLDatabaseToolInput(BaseModel):
query: str = Field(..., description="A detailed and correct SQL query.")


class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""

name: str = "sql_db_query"
description: str = """
Execute a SQL query against the database and get back the result.
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query correctness,
and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput

def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
return self.db.run_no_throw(query)


class _InfoSQLDatabaseToolInput(BaseModel):
table_names: str = Field(
...,
description=(
"A comma-separated list of the table names "
"for which to return the schema. "
"Example input: 'table1, table2, table3'"
),
)


class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""

name: str = "sql_db_schema"
description: str = "Get the schema and sample rows for the specified SQL tables."
args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput

def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(
[t.strip() for t in table_names.split(",")]
)


class _ListSQLDatabaseToolInput(BaseModel):
tool_input: str = Field("", description="An empty string")


class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""

name: str = "sql_db_list_tables"
description: str = (
"Input is an empty string, output is a comma-separated list "
"of tables in the database."
)
args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput

def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get a comma-separated list of table names."""
return ", ".join(self.db.get_usable_table_names())


class _QuerySQLCheckerToolInput(BaseModel):
query: str = Field(..., description="A detailed and SQL query to be checked.")


class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct."""

template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: Any = Field(init=False)
name: str = "sql_db_query_checker"
description: str = """
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

@model_validator(mode="before")
@classmethod
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Any:
if "llm_chain" not in values:
prompt = PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "schema"]
)
llm = cast(BaseLanguageModel, values.get("llm"))

values["llm_chain"] = prompt | llm

if values["llm_chain"].first.input_variables != ["query", "schema"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'schema']" # noqa: E501
)

return values

def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.invoke(
{"query": query, "schema": self.db.schema},
callbacks=run_manager.get_child() if run_manager else None,
).content

async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(
query=query,
schema=self.db.schema,
callbacks=run_manager.get_child() if run_manager else None,
)
Empty file.
Loading