diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index 5aa0efba..51cf5c23 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -3,7 +3,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union import sqlalchemy from langchain_core._api import deprecated @@ -544,11 +544,22 @@ def run( *, parameters: Optional[Dict[str, Any]] = None, execution_options: Optional[Dict[str, Any]] = None, - ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]: - """Execute a SQL command and return a string representing the results. + response_format: Literal["content", "content_and_artifact"] = "content", + ) -> Union[str, Sequence[Dict[str, Any]], Result[Any], Tuple[str, Any]]: + """Execute a SQL command and return its result. - If the statement returns rows, a string of the results is returned. - If the statement returns no rows, an empty string is returned. + Args: + command: The SQL command to execute. + fetch: The number of rows to fetch. Can be "one", "all", or "cursor". + include_columns: Whether to include column names in the result. + parameters: A dictionary of parameters to pass to the SQL command. + execution_options: A dictionary of execution options for the engine. + response_format: The format of the response. Defaults to "content". + + Returns: + A string representation of the results, or a cursor object. If + `response_format` is 'content_and_artifact', returns a tuple of + (string_representation, processed_result_list). """ result = self._execute( command, fetch, parameters=parameters, execution_options=execution_options @@ -568,10 +579,12 @@ def run( if not include_columns: res = [tuple(row.values()) for row in res] # type: ignore[misc] - if not res: - return "" + string_result = "" if not res else str(res) + + if response_format == "content_and_artifact": + return string_result, res else: - return str(res) + return string_result def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables. diff --git a/libs/community/tests/unit_tests/utilities/test_sql_database.py b/libs/community/tests/unit_tests/utilities/test_sql_database.py new file mode 100644 index 00000000..6bb66809 --- /dev/null +++ b/libs/community/tests/unit_tests/utilities/test_sql_database.py @@ -0,0 +1,20 @@ +from langchain_community.utilities.sql_database import SQLDatabase + + +def test_sql_database_run_content_and_artifact(): + db = SQLDatabase.from_uri("sqlite+pysqlite:///:memory:") + db.run("CREATE TABLE test (id INTEGER, name TEXT);") + db.run("INSERT INTO test (id, name) VALUES (1, 'foo');") + db.run("INSERT INTO test (id, name) VALUES (2, 'bar');") + + # Test content_and_artifact format + string_result, artifact = db.run( + "SELECT * FROM test", response_format="content_and_artifact" + ) + assert "foo" in string_result and "bar" in string_result + assert (1, "foo") in artifact and (2, "bar") in artifact + + # Test default format (string only) + only_string = db.run("SELECT * FROM test") + assert isinstance(only_string, str) + assert "foo" in only_string and "bar" in only_string