diff --git a/libs/ibm/langchain_ibm/agent_toolkits/__init__.py b/libs/ibm/langchain_ibm/agent_toolkits/__init__.py new file mode 100644 index 0000000..be2324b --- /dev/null +++ b/libs/ibm/langchain_ibm/agent_toolkits/__init__.py @@ -0,0 +1,3 @@ +from .sql_toolkit import WatsonxSQLDatabaseToolkit + +__all__ = ["WatsonxSQLDatabaseToolkit"] diff --git a/libs/ibm/langchain_ibm/agent_toolkits/sql_toolkit.py b/libs/ibm/langchain_ibm/agent_toolkits/sql_toolkit.py new file mode 100644 index 0000000..087e2a6 --- /dev/null +++ b/libs/ibm/langchain_ibm/agent_toolkits/sql_toolkit.py @@ -0,0 +1,76 @@ +"""IBM watsonx.ai Toolkit wrapper.""" + +from typing import List + +from langchain_core.language_models import BaseLanguageModel +from langchain_core.tools import BaseTool +from langchain_core.tools.base import BaseToolkit +from pydantic import ConfigDict, Field + +from ..utilities.sql_database import WatsonxSQLDatabase +from .tool import ( + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLCheckerTool, + QuerySQLDatabaseTool, +) + + +class WatsonxSQLDatabaseToolkit(BaseToolkit): + """Toolkit for interacting with IBM watsonx.ai databases.""" + + db: WatsonxSQLDatabase = Field(exclude=True) + """Instance of the watsonx SQL database.""" + + llm: BaseLanguageModel = Field(exclude=True) + """Instance of the LLM.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + list_sql_database_tool = ListSQLDatabaseTool(db=self.db) + info_sql_database_tool_description = ( + "Input to this tool is a comma-separated list of tables, output is the " + "SQL statement with table metadata. " + "Be sure that the tables actually exist by calling " + f"{list_sql_database_tool.name} first! 
" + "Example Input: table1, table2, table3" + ) + info_sql_database_tool = InfoSQLDatabaseTool( + db=self.db, description=info_sql_database_tool_description + ) + query_sql_database_tool_description = ( + "Input to this tool is a detailed and correct SQL query, output is a " + "result from the database. If the query is not correct, an error message " + "will be returned. If an error is returned, rewrite the query, check the " + "query, and try again. If you encounter an issue with Unknown column " + f"'xxxx' in 'field list', use {info_sql_database_tool.name} " + "to query the correct table fields." + ) + query_sql_database_tool = QuerySQLDatabaseTool( + db=self.db, description=query_sql_database_tool_description + ) + query_sql_checker_tool_description = ( + "Use this tool to double check if your query is correct before executing " + "it. Always use this tool before executing a query with " + f"{query_sql_database_tool.name}!" + ) + query_sql_checker_tool = QuerySQLCheckerTool( + db=self.db, llm=self.llm, description=query_sql_checker_tool_description + ) + return [ + query_sql_database_tool, + info_sql_database_tool, + list_sql_database_tool, + query_sql_checker_tool, + ] + + def get_context(self) -> dict: + """Return db context that you may want in agent prompt.""" + return self.db.get_context() + + +WatsonxSQLDatabaseToolkit.model_rebuild() diff --git a/libs/ibm/langchain_ibm/agent_toolkits/tool.py b/libs/ibm/langchain_ibm/agent_toolkits/tool.py new file mode 100644 index 0000000..7032a59 --- /dev/null +++ b/libs/ibm/langchain_ibm/agent_toolkits/tool.py @@ -0,0 +1,180 @@ +"""Tools for interacting with a watsonx SQL databases via pyarrow.flight.FlightClient. 
+ +Based on the langchain_community.tools.sql_database.tool module.""" + +from typing import Any, Dict, Optional, Type, cast + +from langchain_core.callbacks import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import PromptTemplate +from langchain_core.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase + +QUERY_CHECKER = """ +{query} +Double check the query above for common mistakes, including: +- Using NOT IN with NULL values +- Using UNION when UNION ALL should have been used +- Using BETWEEN for exclusive ranges +- Data type mismatch in predicates +- Properly quoting identifiers +- Using the correct number of arguments for functions +- Casting to the correct data type +- Using the proper columns for joins +- Make sure that schema name `{schema}` is added to the table name, e.g. {schema}.table1 + +If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query. + +Output the final SQL query only. + +SQL Query: """ # noqa: E501 + + +class BaseSQLDatabaseTool(BaseModel): + """Base tool for interacting with a SQL database.""" + + db: WatsonxSQLDatabase = Field(exclude=True) + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +class _QuerySQLDatabaseToolInput(BaseModel): + query: str = Field(..., description="A detailed and correct SQL query.") + + +class QuerySQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for querying a SQL database.""" + + name: str = "sql_db_query" + description: str = """ + Execute a SQL query against the database and get back the result. + If the query is not correct, an error message will be returned. + If an error is returned, rewrite the query, check the query correctness, + and try again. 
+ """ + args_schema: Type[BaseModel] = _QuerySQLDatabaseToolInput + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Execute the query, return the results or an error message.""" + return self.db.run_no_throw(query) + + +class _InfoSQLDatabaseToolInput(BaseModel): + table_names: str = Field( + ..., + description=( + "A comma-separated list of the table names " + "for which to return the schema. " + "Example input: 'table1, table2, table3'" + ), + ) + + +class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting metadata about a SQL database.""" + + name: str = "sql_db_schema" + description: str = "Get the schema and sample rows for the specified SQL tables." + args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput + + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get the schema for tables in a comma-separated list.""" + return self.db.get_table_info_no_throw( + [t.strip() for t in table_names.split(",")] + ) + + +class _ListSQLDatabaseToolInput(BaseModel): + tool_input: str = Field("", description="An empty string") + + +class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): + """Tool for getting tables names.""" + + name: str = "sql_db_list_tables" + description: str = ( + "Input is an empty string, output is a comma-separated list " + "of tables in the database." 
) + args_schema: Type[BaseModel] = _ListSQLDatabaseToolInput + + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Get a comma-separated list of table names.""" + return ", ".join(self.db.get_usable_table_names()) + + +class _QuerySQLCheckerToolInput(BaseModel): + query: str = Field(..., description="A detailed and correct SQL query to be checked.") + + +class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool): + """Use an LLM to check if a query is correct.""" + + template: str = QUERY_CHECKER + llm: BaseLanguageModel + llm_chain: Any = Field(init=False) + name: str = "sql_db_query_checker" + description: str = """ + Use this tool to double check if your query is correct before executing it. + Always use this tool before executing a query with sql_db_query! + """ + args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput + + @model_validator(mode="before") + @classmethod + def initialize_llm_chain(cls, values: Dict[str, Any]) -> Any: + if "llm_chain" not in values: + prompt = PromptTemplate( + template=QUERY_CHECKER, input_variables=["query", "schema"] + ) + llm = cast(BaseLanguageModel, values.get("llm")) + + values["llm_chain"] = prompt | llm + + if values["llm_chain"].first.input_variables != ["query", "schema"]: + raise ValueError( + "LLM chain for QueryCheckerTool must have input variables ['query', 'schema']" # noqa: E501 + ) + + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the LLM to check the query.""" + return self.llm_chain.invoke( + {"query": query, "schema": self.db.schema}, + config={"callbacks": run_manager.get_child() if run_manager else None}, + ).content + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: + result = await self.llm_chain.ainvoke( + {"query": query, "schema": self.db.schema}, + config={"callbacks": run_manager.get_child() if run_manager else None}, + ) + return result.content 
diff --git a/libs/ibm/langchain_ibm/utilities/__init__.py b/libs/ibm/langchain_ibm/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/ibm/langchain_ibm/utilities/sql_database.py b/libs/ibm/langchain_ibm/utilities/sql_database.py new file mode 100644 index 0000000..5452a92 --- /dev/null +++ b/libs/ibm/langchain_ibm/utilities/sql_database.py @@ -0,0 +1,404 @@ +import urllib.parse +from typing import Any, Dict, Iterable, List, Optional, Union + +try: + import pyarrow.flight as flight # type: ignore[import-untyped] +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "To use WatsonxSQLDatabase one needs to install langchain-ibm with extras " + "`sql_toolkit`: `pip install langchain-ibm[sql_toolkit]`" + ) from e + +from ibm_watsonx_ai import APIClient, Credentials # type: ignore[import-untyped] +from ibm_watsonx_ai.helpers.connections.flight_sql_service import ( # type: ignore[import-untyped] + FlightSQLClient, +) +from langchain_core.utils.utils import from_env + + +def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: + """ + Truncate a string to a certain number of characters, based on the max string + length. + + Based on the analogous function from langchain_community.utilities.sql_database + """ + + if not isinstance(content, str) or length <= 0: + return content + + if len(content) <= length: + return content + + return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix + + +def _validate_param(value: Optional[str], key: str, env_key: str) -> None: + if value is None: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) + return None + + +def _from_env(env_var_name: str) -> str | None: + """Read env variable. 
If it is not set, return None.""" + return from_env(env_var_name, default=None)() + + +def pretty_print_table_info(schema: str, table_name: str, table_info: dict) -> str: + def convert_column_data(field_metadata: dict) -> str: + name = field_metadata.get("name") + + field_metadata_type = field_metadata.get("type", {}) + native_type = field_metadata_type.get("native_type") + nullable = field_metadata_type.get("nullable") + + return f"{name} {native_type}{'' if nullable else ' NOT NULL'}," + + create_table_template = """ +CREATE TABLE {schema}.{table_name} ( +\t{column_definitions}{primary_key} +\t)""" + + primary_key: dict = next( + filter( + lambda el: el.get("name") == "primary_key", + table_info.get("extended_metadata", [{}]), + ), + {}, + ) + key_columns = primary_key.get("value", {}).get("key_columns", []) + + return create_table_template.format( + schema=schema, + table_name=table_name, + column_definitions="\n\t".join( + [ + convert_column_data(field_metadata=field_metadata) + for field_metadata in table_info["fields"] + ] + ), + primary_key=f"\n\tPRIMARY KEY ({', '.join(key_columns)})" + if primary_key + else "", + ) + + +class WatsonxSQLDatabase: + """Watsonx SQL Database class for IBM watsonx.ai databases + connection assets. Uses Arrow Flight to interact with databases via watsonx. 
+ + :param connection_id: ID of db connection asset + :type connection_id: str + + :param schema: name of the database schema from which tables will be read + :type schema: str + + :param project_id: ID of project, defaults to None + :type project_id: Optional[str], optional + + :param space_id: ID of space, defaults to None + :type space_id: Optional[str], optional + + :param url: URL to the Watson Machine Learning or CPD instance, defaults to None + :type url: Optional[str], optional + + :param apikey: API key to the Watson Machine Learning + or CPD instance, defaults to None + :type apikey: Optional[str], optional + + :param token: service token, used in token authentication, defaults to None + :type token: Optional[str], optional + + :param password: password to the CPD instance., defaults to None + :type password: Optional[str], optional + + :param username: username to the CPD instance., defaults to None + :type username: Optional[str], optional + + :param instance_id: instance_id of the CPD instance., defaults to None + :type instance_id: Optional[str], optional + + :param version: version of the CPD instance, defaults to None + :type version: Optional[str], optional + + :param verify: certificate verification flag, defaults to None + :type verify: Union[str, bool, None], optional + + :param watsonx_client: instance of `ibm_watsonx_ai.APIClient`, defaults to None + :type watsonx_client: Optional[APIClient], optional + + :param ignore_tables: list of tables that will be ignored, defaults to None + :type ignore_tables: Optional[List[str]], optional + + :param include_tables: list of tables that should be included, defaults to None + :type include_tables: Optional[List[str]], optional + + :param sample_rows_in_table_info: number of first rows to be added to the + table info, defaults to 3 + :type sample_rows_in_table_info: int, optional + + :param max_string_length: max length of string, defaults to 300 + :type max_string_length: int, optional + + :raises 
ValueError: raise if some required credentials are missing + :raises RuntimeError: raise if no tables found in given schema + """ + + def __init__( + self, + *, + connection_id: str, + schema: str, + project_id: Optional[str] = None, + space_id: Optional[str] = None, + url: Optional[str] = None, + apikey: Optional[str] = None, + token: Optional[str] = None, + password: Optional[str] = None, + username: Optional[str] = None, + instance_id: Optional[str] = None, + version: Optional[str] = None, + verify: Union[str, bool, None] = None, + watsonx_client: Optional[APIClient] = None, #: :meta private: + ignore_tables: Optional[List[str]] = None, + include_tables: Optional[List[str]] = None, + sample_rows_in_table_info: int = 3, + max_string_length: int = 300, + ) -> None: + if include_tables and ignore_tables: + raise ValueError("Cannot specify both include_tables and ignore_tables") + + self.schema = schema + self._ignore_tables = set(ignore_tables) if ignore_tables else set() + self._include_tables = set(include_tables) if include_tables else set() + self._sample_rows_in_table_info = sample_rows_in_table_info + + self._max_string_length = max_string_length + + if watsonx_client is None: + url = url or _from_env("WATSONX_URL") + _validate_param(url, "url", "WATSONX_URL") + + parsed_url = urllib.parse.urlparse(url) + if parsed_url.netloc.endswith(".cloud.ibm.com"): # type: ignore[arg-type] + token = token or _from_env("WATSONX_TOKEN") + apikey = apikey or _from_env("WATSONX_APIKEY") + if not token and not apikey: + raise ValueError( + "Did not find 'apikey' or 'token'," + " please add an environment variable" + " `WATSONX_APIKEY` or 'WATSONX_TOKEN' " + "which contains it," + " or pass 'apikey' or 'token'" + " as a named parameter." 
+ ) + else: + token = token or _from_env("WATSONX_TOKEN") + apikey = apikey or _from_env("WATSONX_APIKEY") + password = password or _from_env("WATSONX_PASSWORD") + if not token and not password and not apikey: + raise ValueError( + "Did not find 'token', 'password' or 'apikey'," + " please add an environment variable" + " `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' " + "which contains it," + " or pass 'token', 'password' or 'apikey'" + " as a named parameter." + ) + + try: + _validate_param(token, "token", "WATSONX_TOKEN") + except ValueError: + pass + + try: + _validate_param(password, "password", "WATSONX_PASSWORD") + except ValueError: + pass + else: + username = username or _from_env("WATSONX_USERNAME") + _validate_param(username, "username", "WATSONX_USERNAME") + + try: + _validate_param(apikey, "apikey", "WATSONX_APIKEY") + except ValueError: + pass + else: + username = username or _from_env("WATSONX_USERNAME") + _validate_param(username, "username", "WATSONX_USERNAME") + + instance_id = instance_id or _from_env("WATSONX_INSTANCE_ID") + _validate_param(instance_id, "instance_id", "WATSONX_INSTANCE_ID") + + credentials = Credentials( + url=url, + api_key=apikey, + token=token, + password=password, + username=username, + instance_id=instance_id, + version=version, + verify=verify, + ) + project_id = project_id or _from_env("WATSONX_PROJECT_ID") + space_id = space_id or _from_env("WATSONX_SPACE_ID") + self.watsonx_client = APIClient( + credentials=credentials, project_id=project_id, space_id=space_id + ) + else: + self.watsonx_client = watsonx_client + + context_id: dict[str, str | None] = {"project_id": None, "space_id": None} + if project_id is not None: + context_id["project_id"] = project_id + elif space_id is not None: + context_id["space_id"] = space_id + elif self.watsonx_client.default_project_id is not None: + context_id["project_id"] = self.watsonx_client.default_project_id + elif self.watsonx_client.default_space_id is not None: + 
context_id["space_id"] = self.watsonx_client.default_space_id + else: + raise ValueError("Either project_id or space_id is required.") + + self._flight_sql_client = FlightSQLClient( + connection_id=connection_id, api_client=self.watsonx_client, **context_id + ) + + with self._flight_sql_client as flight_sql_client: + _tables = flight_sql_client.get_tables(schema=self.schema).get("assets") + if _tables is not None: + self._all_tables = { + table.get("name") for table in _tables if table.get("name") + } + else: + raise RuntimeError(f"No tables found in the schema: {schema}") + + if self._include_tables: + missing_tables = self._include_tables - self._all_tables + if missing_tables: + raise ValueError( + f"include_tables {missing_tables} not found in database" + ) + if self._ignore_tables: + missing_tables = self._ignore_tables - self._all_tables + if missing_tables: + raise ValueError( + f"ignore_tables {missing_tables} not found in database" + ) + + self._meta_all_tables = { + table_name: flight_sql_client.get_table_info( + table_name=table_name, schema=self.schema + ) + for table_name in self._all_tables + if table_name in (self._include_tables or self._all_tables) + and table_name not in (self._ignore_tables or {}) + } + + def get_usable_table_names(self) -> Iterable[str]: + """Get names of tables available.""" + if self._include_tables: + return sorted(self._include_tables) + return sorted(self._all_tables - self._ignore_tables) + + def _execute( + self, + command: str, + ) -> dict: + """Execute a command.""" + + with self._flight_sql_client as flight_sql_client: + results = flight_sql_client.execute(query=command) + + return results.to_dict("records") + + def run(self, command: str, include_columns: bool = False) -> str: + """Execute a SQL command and return a string representing the results.""" + result = self._execute(command) + + res: List[Dict] = [ + { + column: truncate_word(value, length=self._max_string_length) + for column, value in r.items() + } + for r 
in result + ] + + if not include_columns: + res = [tuple(row.values()) for row in res] # type: ignore[misc] + + return str(res) if res else "" + + def run_no_throw( + self, + command: str, + include_columns: bool = False, + ) -> str: + """Execute a SQL command and return a string representing the results. + + If the statement throws an error, the error message is returned. + """ + try: + return self.run( + command, + include_columns=include_columns, + ) + except flight.FlightError as e: + """Format the error message""" + return f"Error: {e}" + + def get_table_info(self, table_names: Optional[Iterable[str]] = None) -> str: + """Get information about specified tables.""" + + all_table_names = self.get_usable_table_names() + if table_names is not None: + missing_tables = set(table_names).difference(all_table_names) + if missing_tables: + raise ValueError(f"table_names {missing_tables} not found in database") + + extra_interaction_properties = { + "schema_name": self.schema, + "row_limit": self._sample_rows_in_table_info, + } + + with self._flight_sql_client as flight_sql_client: + if table_names is None: + table_names = self._all_tables + + return "\n\n".join( + [ + pretty_print_table_info( + schema=self.schema, + table_name=table_name, + table_info=self._meta_all_tables[table_name], + ) + + f"\n\nFirst {self._sample_rows_in_table_info} rows " + + f"of table {table_name}:\n\n" + + flight_sql_client.execute( + None, + interaction_properties=extra_interaction_properties + | {"table_name": table_name}, + ).to_string() + for table_name in table_names + ] + ) + + def get_table_info_no_throw( + self, table_names: Optional[Iterable[str]] = None + ) -> str: + """Get information about specified tables.""" + try: + return self.get_table_info(table_names=table_names) + except (flight.FlightError, ValueError) as e: + """Format the error message""" + return f"Error: {e}" + + def get_context(self) -> Dict[str, Any]: + """Return db context that you may want in agent prompt.""" + 
table_names = list(self.get_usable_table_names()) + table_info = self.get_table_info_no_throw() + return {"table_info": table_info, "table_names": ", ".join(table_names)} diff --git a/libs/ibm/poetry.lock b/libs/ibm/poetry.lock index 27a852e..40c7efe 100644 --- a/libs/ibm/poetry.lock +++ b/libs/ibm/poetry.lock @@ -1130,6 +1130,62 @@ files = [ dev = ["pre-commit", "tox"] testing = ["coverage", "pytest", "pytest-benchmark"] +[[package]] +name = "pyarrow" +version = "21.0.0" +description = "Python library for Apache Arrow" +optional = false +python-versions = ">=3.9" +groups = ["main", "lint", "test"] +files = [ + {file = "pyarrow-21.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:e563271e2c5ff4d4a4cbeb2c83d5cf0d4938b891518e676025f7268c6fe5fe26"}, + {file = "pyarrow-21.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fee33b0ca46f4c85443d6c450357101e47d53e6c3f008d658c27a2d020d44c79"}, + {file = "pyarrow-21.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:7be45519b830f7c24b21d630a31d48bcebfd5d4d7f9d3bdb49da9cdf6d764edb"}, + {file = "pyarrow-21.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:26bfd95f6bff443ceae63c65dc7e048670b7e98bc892210acba7e4995d3d4b51"}, + {file = "pyarrow-21.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:bd04ec08f7f8bd113c55868bd3fc442a9db67c27af098c5f814a3091e71cc61a"}, + {file = "pyarrow-21.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9b0b14b49ac10654332a805aedfc0147fb3469cbf8ea951b3d040dab12372594"}, + {file = "pyarrow-21.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:9d9f8bcb4c3be7738add259738abdeddc363de1b80e3310e04067aa1ca596634"}, + {file = "pyarrow-21.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c077f48aab61738c237802836fc3844f85409a46015635198761b0d6a688f87b"}, + {file = "pyarrow-21.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:689f448066781856237eca8d1975b98cace19b8dd2ab6145bf49475478bcaa10"}, + {file = "pyarrow-21.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", 
hash = "sha256:479ee41399fcddc46159a551705b89c05f11e8b8cb8e968f7fec64f62d91985e"}, + {file = "pyarrow-21.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:40ebfcb54a4f11bcde86bc586cbd0272bac0d516cfa539c799c2453768477569"}, + {file = "pyarrow-21.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8d58d8497814274d3d20214fbb24abcad2f7e351474357d552a8d53bce70c70e"}, + {file = "pyarrow-21.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:585e7224f21124dd57836b1530ac8f2df2afc43c861d7bf3d58a4870c42ae36c"}, + {file = "pyarrow-21.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:555ca6935b2cbca2c0e932bedd853e9bc523098c39636de9ad4693b5b1df86d6"}, + {file = "pyarrow-21.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3a302f0e0963db37e0a24a70c56cf91a4faa0bca51c23812279ca2e23481fccd"}, + {file = "pyarrow-21.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:b6b27cf01e243871390474a211a7922bfbe3bda21e39bc9160daf0da3fe48876"}, + {file = "pyarrow-21.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e72a8ec6b868e258a2cd2672d91f2860ad532d590ce94cdf7d5e7ec674ccf03d"}, + {file = "pyarrow-21.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b7ae0bbdc8c6674259b25bef5d2a1d6af5d39d7200c819cf99e07f7dfef1c51e"}, + {file = "pyarrow-21.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:58c30a1729f82d201627c173d91bd431db88ea74dcaa3885855bc6203e433b82"}, + {file = "pyarrow-21.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:072116f65604b822a7f22945a7a6e581cfa28e3454fdcc6939d4ff6090126623"}, + {file = "pyarrow-21.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:cf56ec8b0a5c8c9d7021d6fd754e688104f9ebebf1bf4449613c9531f5346a18"}, + {file = "pyarrow-21.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e99310a4ebd4479bcd1964dff9e14af33746300cb014aa4a3781738ac63baf4a"}, + {file = "pyarrow-21.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d2fe8e7f3ce329a71b7ddd7498b3cfac0eeb200c2789bd840234f0dc271a8efe"}, + {file = 
"pyarrow-21.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f522e5709379d72fb3da7785aa489ff0bb87448a9dc5a75f45763a795a089ebd"}, + {file = "pyarrow-21.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:69cbbdf0631396e9925e048cfa5bce4e8c3d3b41562bbd70c685a8eb53a91e61"}, + {file = "pyarrow-21.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:731c7022587006b755d0bdb27626a1a3bb004bb56b11fb30d98b6c1b4718579d"}, + {file = "pyarrow-21.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc56bc708f2d8ac71bd1dcb927e458c93cec10b98eb4120206a4091db7b67b99"}, + {file = "pyarrow-21.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:186aa00bca62139f75b7de8420f745f2af12941595bbbfa7ed3870ff63e25636"}, + {file = "pyarrow-21.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:a7a102574faa3f421141a64c10216e078df467ab9576684d5cd696952546e2da"}, + {file = "pyarrow-21.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:1e005378c4a2c6db3ada3ad4c217b381f6c886f0a80d6a316fe586b90f77efd7"}, + {file = "pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65f8e85f79031449ec8706b74504a316805217b35b6099155dd7e227eef0d4b6"}, + {file = "pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3a81486adc665c7eb1a2bde0224cfca6ceaba344a82a971ef059678417880eb8"}, + {file = "pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:fc0d2f88b81dcf3ccf9a6ae17f89183762c8a94a5bdcfa09e05cfe413acf0503"}, + {file = "pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6299449adf89df38537837487a4f8d3bd91ec94354fdd2a7d30bc11c48ef6e79"}, + {file = "pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10"}, + {file = "pyarrow-21.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a7f6524e3747e35f80744537c78e7302cd41deee8baa668d56d55f77d9c464b3"}, + {file = "pyarrow-21.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = 
"sha256:203003786c9fd253ebcafa44b03c06983c9c8d06c3145e37f1b76a1f317aeae1"}, + {file = "pyarrow-21.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3b4d97e297741796fead24867a8dabf86c87e4584ccc03167e4a811f50fdf74d"}, + {file = "pyarrow-21.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:898afce396b80fdda05e3086b4256f8677c671f7b1d27a6976fa011d3fd0a86e"}, + {file = "pyarrow-21.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:067c66ca29aaedae08218569a114e413b26e742171f526e828e1064fcdec13f4"}, + {file = "pyarrow-21.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0c4e75d13eb76295a49e0ea056eb18dbd87d81450bfeb8afa19a7e5a75ae2ad7"}, + {file = "pyarrow-21.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdc4c17afda4dab2a9c0b79148a43a7f4e1094916b3e18d8975bfd6d6d52241f"}, + {file = "pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc"}, +] + +[package.extras] +test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] + [[package]] name = "pycparser" version = "2.22" @@ -1883,7 +1939,10 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ [package.extras] cffi = ["cffi (>=1.11)"] +[extras] +sql-toolkit = ["pyarrow"] + [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "9b0cdb3fdffaec4ba72ba857149d6ad43c9d75d9b38828ebfb7efd0fe9a0abfa" +content-hash = "0fc9d4365927739850722dd85bffa9d5a1ae3ade10334e7b3d7e4472a749094c" diff --git a/libs/ibm/pyproject.toml b/libs/ibm/pyproject.toml index 51411e9..bc55d10 100644 --- a/libs/ibm/pyproject.toml +++ b/libs/ibm/pyproject.toml @@ -13,7 +13,11 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<3.14" langchain-core = "^0.3.39" -ibm-watsonx-ai = "^1.3.33" +ibm-watsonx-ai = "^1.3.34" +pyarrow = {version = ">=3.0.0", optional = true} + +[tool.poetry.extras] +sql_toolkit = ["pyarrow"] [tool.poetry.group.test] optional = true @@ -27,6 +31,7 @@ pytest-watcher = "^0.3.4" pytest-asyncio = "^0.21.1" 
pytest-cov = "^4.1.0" langchain-tests = "0.3.13" +pyarrow = ">=3.0.0" [tool.poetry.group.codespell] optional = true @@ -44,6 +49,8 @@ optional = true [tool.poetry.group.lint.dependencies] ruff = "^0.5" +pyarrow = ">=3.0.0" + [tool.poetry.group.typing.dependencies] mypy = "^1.10" diff --git a/libs/ibm/tests/unit_tests/agent_toolkits/test_tool.py b/libs/ibm/tests/unit_tests/agent_toolkits/test_tool.py new file mode 100644 index 0000000..4704d96 --- /dev/null +++ b/libs/ibm/tests/unit_tests/agent_toolkits/test_tool.py @@ -0,0 +1,89 @@ +from typing import Any +from unittest.mock import Mock + +import pytest +from pyarrow import flight # type: ignore[import-untyped] + +from langchain_ibm.agent_toolkits.tool import ( + InfoSQLDatabaseTool, + ListSQLDatabaseTool, + QuerySQLDatabaseTool, +) +from langchain_ibm.utilities.sql_database import WatsonxSQLDatabase + + +@pytest.fixture +def mock_db() -> Mock: + return Mock(spec=WatsonxSQLDatabase) + + +### QuerySQLDatabaseTool + + +def test_query_tool_run_with_valid_query(mock_db: Mock) -> None: + """Test running a valid query.""" + mock_db.run_no_throw.return_value = "Valid query result" + tool = QuerySQLDatabaseTool(db=mock_db) + + query = "SELECT * FROM table" + result = tool._run(query, None) + assert result == "Valid query result" + mock_db.run_no_throw.assert_called_once_with(query) + + +def test_query_tool_run_with_invalid_query(mock_db: Mock) -> None: + """Test running an invalid query.""" + + def mock_run_no_throw(*args: Any, **kwargs: Any) -> None: + raise flight.FlightError("Invalid query") + + mock_db.run.side_effect = mock_run_no_throw + tool = QuerySQLDatabaseTool(db=mock_db) + + query = "SELECT * FROM non_existent_table" + result = tool._run(query, None) + assert result.startswith("Error"), "Expected an error message" + mock_db.run_no_throw.assert_called_once_with(query) + + +### InfoSQLDatabaseTool + + +def test_info_tool_run_with_valid_query(mock_db: Mock) -> None: + """Test running a valid query.""" + 
mock_db.get_table_info_no_throw.return_value = "schema_info" + tool = InfoSQLDatabaseTool(db=mock_db) + + result = tool._run("table1,table2", None) + + mock_db.get_table_info_no_throw.assert_called_once_with(["table1", "table2"]) + assert result == "schema_info" + + +def test_info_tool_run_with_invalid_query(mock_db: Mock) -> None: + """Test running an invalid query.""" + + def mock_get_table_info(*args: Any, **kwargs: Any) -> None: + raise flight.FlightError("Table not Found") + + mock_db.get_table_info.side_effect = mock_get_table_info + tool = InfoSQLDatabaseTool(db=mock_db) + + result = tool._run("tableX", None) + + mock_db.get_table_info_no_throw.assert_called_once_with(["tableX"]) + assert result.startswith("Error"), "Expected an error message" + + +### ListSQLDatabaseTool + + +def test_list_tool_run_with_valid_query(mock_db: Mock) -> None: + """Test running a valid query.""" + mock_db.get_usable_table_names.return_value = ["table1", "table2"] + tool = ListSQLDatabaseTool(db=mock_db) + + result = tool._run() + + mock_db.get_usable_table_names.assert_called_once_with() + assert result == "table1, table2" diff --git a/libs/ibm/tests/unit_tests/utilities/__init__.py b/libs/ibm/tests/unit_tests/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/ibm/tests/unit_tests/utilities/test_sql_database.py b/libs/ibm/tests/unit_tests/utilities/test_sql_database.py new file mode 100644 index 0000000..91b9448 --- /dev/null +++ b/libs/ibm/tests/unit_tests/utilities/test_sql_database.py @@ -0,0 +1,532 @@ +import os +from typing import Any, Dict, Generator +from unittest import mock +from unittest.mock import Mock, patch + +import pandas as pd # type: ignore[import-untyped] +import pytest +from pyarrow import flight # type: ignore + +from langchain_ibm.utilities.sql_database import ( + WatsonxSQLDatabase, + pretty_print_table_info, + truncate_word, +) + +CONNECTION_ID = "test_connection_id" +PROJECT_ID = "test_project_id" + + +@pytest.fixture 
+def schema() -> str: + return "test_schema" + + +@pytest.fixture +def table_name() -> str: + return "test_table" + + +@pytest.fixture +def table_info() -> dict: + return { + "fields": [ + {"name": "id", "type": {"native_type": "INT", "nullable": False}}, + {"name": "name", "type": {"native_type": "VARCHAR(255)", "nullable": True}}, + {"name": "age", "type": {"native_type": "INT", "nullable": True}}, + ], + "extended_metadata": [ + {"name": "primary_key", "value": {"key_columns": ["id"]}} + ], + } + + +@pytest.fixture +def clear_env() -> Generator[None, None, None]: + with mock.patch.dict(os.environ, clear=True): + yield + + +class MockFlightSQLClient: + def __init__(self, *args: Any, **kwargs: Any) -> None: + pass + + def __enter__(self, *args: Any, **kwargs: Any) -> "MockFlightSQLClient": + return self + + def __exit__(self, *args: Any) -> None: + pass + + def get_tables(self, *args: Any, **kwargs: Any) -> Dict: + return {"assets": [{"name": "table1"}, {"name": "table2"}]} + + def get_table_info(self, table_name: str, *args: Any, **kwargs: Any) -> Dict: + if table_name == "table1": + return { + "path": "/public/table1", + "fields": [ + {"name": "id", "type": {"native_type": "INT", "nullable": False}}, + { + "name": "name", + "type": {"native_type": "VARCHAR(255)", "nullable": True}, + }, + {"name": "age", "type": {"native_type": "INT", "nullable": True}}, + ], + "extended_metadata": [ + {"name": "primary_key", "value": {"key_columns": ["id"]}} + ], + } + elif table_name == "table2": + return { + "path": "/public/table2", + "fields": [ + {"name": "id", "type": {"native_type": "INT", "nullable": False}}, + { + "name": "name", + "type": {"native_type": "VARCHAR(255)", "nullable": True}, + }, + {"name": "age", "type": {"native_type": "INT", "nullable": True}}, + ], + "extended_metadata": [ + {"name": "primary_key", "value": {"key_columns": ["id"]}} + ], + } + else: + raise flight.FlightError("Table not found") + + def execute(self, *args: Any, **kwargs: Any) -> 
pd.DataFrame: + if kwargs.get("interaction_properties", {}).get("table_name") == "table1" or ( + "table1" in kwargs.get("query", "") + ): + return pd.DataFrame({"id": [1], "name": ["test"], "age": [35]}) + + elif "table1" not in kwargs.get("query", ""): + raise flight.FlightError("Table not found") + + else: + raise ValueError("syntax error") + + +### truncate_word + + +def test_truncate_word() -> None: + assert truncate_word("This is a test", length=11) == "This is..." + assert truncate_word("Short", length=10) == "Short" + assert truncate_word("A", length=3) == "A" + assert truncate_word("", length=10) == "" + assert ( + truncate_word("This is a longer test", length=20, suffix="***") + == "This is a longer***" + ) + assert truncate_word(12345, length=10) == 12345 # Non-string input + assert truncate_word("This is a test", length=0) == "This is a test" # Length <= 0 + + +def test_truncate_word_edge_cases() -> None: + assert truncate_word("This is a test", length=8) == "This..." + assert truncate_word("This is a test", length=11) == "This is..." + assert truncate_word("This is a test", length=13) == "This is a..." 
+ assert (
+ truncate_word("This is a test", length=16) == "This is a test"
+ ) # Length >= string length
+
+
+### pretty_print_table_info
+
+
+def test_pretty_print_table_info(
+ schema: str, table_name: str, table_info: dict
+) -> None:
+ expected_output = """
+CREATE TABLE test_schema.test_table (
+\tid INT NOT NULL,
+\tname VARCHAR(255),
+\tage INT,
+\tPRIMARY KEY (id)
+\t)"""
+ assert pretty_print_table_info(schema, table_name, table_info) == expected_output
+
+
+def test_pretty_print_table_info_with_nullable_columns() -> None:
+ schema = "another_schema"
+ table_name = "another_table"
+ table_info = {
+ "fields": [
+ {
+ "name": "email",
+ "type": {"native_type": "VARCHAR(255)", "nullable": True},
+ },
+ {
+ "name": "created_at",
+ "type": {"native_type": "TIMESTAMP", "nullable": True},
+ },
+ ],
+ "extended_metadata": [
+ {"name": "primary_key", "value": {"key_columns": ["email"]}}
+ ],
+ }
+ expected_output = """
+CREATE TABLE another_schema.another_table (
+\temail VARCHAR(255),
+\tcreated_at TIMESTAMP,
+\tPRIMARY KEY (email)
+\t)"""
+ assert pretty_print_table_info(schema, table_name, table_info) == expected_output
+
+
+def test_pretty_print_table_info_without_primary_key() -> None:
+ schema = "no_pk_schema"
+ table_name = "no_pk_table"
+ table_info = {
+ "fields": [
+ {"name": "value1", "type": {"native_type": "INT", "nullable": False}},
+ {
+ "name": "value2",
+ "type": {"native_type": "VARCHAR(255)", "nullable": True},
+ },
+ ]
+ }
+ expected_output = """
+CREATE TABLE no_pk_schema.no_pk_table (
+\tvalue1 INT NOT NULL,
+\tvalue2 VARCHAR(255),
+\t)"""
+ assert pretty_print_table_info(schema, table_name, table_info) == expected_output
+
+
+### WatsonxSQLDatabase
+
+
+def test_initialize_watsonx_sql_database_without_url(
+ clear_env: None, schema: str
+) -> None:
+ with pytest.raises(ValueError) as e:
+ WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema)
+
+ assert "url" in str(e.value)
+ assert "WATSONX_URL" in str(e.value)
+
+
+def 
test_initialize_watsonx_sql_database_cloud_bad_path( + clear_env: None, schema: str +) -> None: + with pytest.raises(ValueError) as e: + WatsonxSQLDatabase( + connection_id=CONNECTION_ID, + schema=schema, + url="https://us-south.ml.cloud.ibm.com", + ) # type: ignore[arg-type] + + assert "apikey" in str(e.value) and "token" in str(e.value) + assert "WATSONX_APIKEY" in str(e.value) and "WATSONX_TOKEN" in str(e.value) + + +def test_initialize_watsonx_sql_database_cpd_bad_path_without_all( + clear_env: None, schema: str +) -> None: + with pytest.raises(ValueError) as e: + WatsonxSQLDatabase( + connection_id=CONNECTION_ID, + schema=schema, + url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com", # type: ignore[arg-type] + ) + assert ( + "apikey" in str(e.value) + and "password" in str(e.value) + and "token" in str(e.value) + ) + assert ( + "WATSONX_APIKEY" in str(e.value) + and "WATSONX_PASSWORD" in str(e.value) + and "WATSONX_TOKEN" in str(e.value) + ) + + +def test_initialize_watsonx_sql_database_cpd_bad_path_password_without_username( + clear_env: None, schema: str +) -> None: + with pytest.raises(ValueError) as e: + WatsonxSQLDatabase( + connection_id=CONNECTION_ID, + schema=schema, + url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com", # type: ignore[arg-type] + password="test_password", # type: ignore[arg-type] + ) + assert "username" in str(e.value) + assert "WATSONX_USERNAME" in str(e.value) + + +def test_initialize_watsonx_sql_database_cpd_bad_path_apikey_without_username( + clear_env: None, schema: str +) -> None: + with pytest.raises(ValueError) as e: + WatsonxSQLDatabase( + connection_id=CONNECTION_ID, + schema=schema, + url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com", # type: ignore[arg-type] + apikey="test_apikey", # type: ignore[arg-type] + ) + + assert "username" in str(e.value) + assert "WATSONX_USERNAME" in str(e.value) + + +def test_initialize_watsonx_sql_database_cpd_bad_path_without_instance_id( + clear_env: None, schema: str +) -> None: + with 
pytest.raises(ValueError) as e: + WatsonxSQLDatabase( + connection_id=CONNECTION_ID, + schema=schema, + url="https://cpd-zen.apps.cpd48.cp.fyre.ibm.com", # type: ignore[arg-type] + apikey="test_apikey", # type: ignore[arg-type] + username="test_user", # type: ignore[arg-type] + ) + assert "instance_id" in str(e.value) + assert "WATSONX_INSTANCE_ID" in str(e.value) + + +def test_initialize_watsonx_sql_database_without_any_params() -> None: + with pytest.raises(TypeError): + WatsonxSQLDatabase() # type: ignore[call-arg] + + +def test_initialize_watsonx_sql_database_valid( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema) + + assert isinstance(wx_sql_database._flight_sql_client, MockFlightSQLClient) + assert wx_sql_database.schema == schema + + +def test_initialize_watsonx_sql_database_include_tables( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, 
v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase( + connection_id=CONNECTION_ID, schema=schema, include_tables=["table1"] + ) + + assert wx_sql_database.get_usable_table_names() == ["table1"] + + +def test_initialize_watsonx_sql_database_ignore_tables( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase( + connection_id=CONNECTION_ID, schema=schema, ignore_tables=["table1"] + ) + + assert wx_sql_database.get_usable_table_names() == ["table2"] + + +def test_initialize_watsonx_sql_database_get_table_info( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema) + expected_output = """ +CREATE TABLE test_schema.table1 ( +\tid INT NOT NULL, +\tname VARCHAR(255), +\tage INT, +\tPRIMARY KEY (id) +\t) + +First 3 rows of table table1: 
+ + id name age +0 1 test 35""" + print(wx_sql_database.get_table_info(["table1"])) + assert wx_sql_database.get_table_info(["table1"]) == expected_output + + +def test_initialize_watsonx_sql_database_get_table_info_no_throw( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema) + with pytest.raises(ValueError): + wx_sql_database.get_table_info(["tableX"]) + assert "tableX" in wx_sql_database.get_table_info_no_throw(["tableX"]) + + +def test_initialize_watsonx_sql_database_run( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema) + + assert ( + wx_sql_database.run(f"SELECT * FROM {schema}.table1", include_columns=True) + == "[{'id': 1, 'name': 'test', 'age': 35}]" + ) + + +def 
test_initialize_watsonx_sql_database_run_no_throw( + schema: str, monkeypatch: pytest.MonkeyPatch +) -> None: + mock_api_client = Mock() + mock_api_client.default_project_id = PROJECT_ID + + with ( + mock.patch.dict(os.environ, clear=True), + patch( + "langchain_ibm.utilities.sql_database.APIClient", + autospec=True, + return_value=mock_api_client, + ), + patch( + "langchain_ibm.utilities.sql_database.FlightSQLClient", + autospec=True, + return_value=MockFlightSQLClient(), + ), + ): + envvars = { + "WATSONX_APIKEY": "test_apikey", + "WATSONX_URL": "https://us-south.ml.cloud.ibm.com", + } + for k, v in envvars.items(): + monkeypatch.setenv(k, v) + + wx_sql_database = WatsonxSQLDatabase(connection_id=CONNECTION_ID, schema=schema) + + assert "Table not found" in wx_sql_database.run_no_throw( + f"SELECT * FROM {schema}.tableX", include_columns=True + )