Skip to content

Commit 5226231

Browse files
authored
carry over 30646 (#6)
1 parent 7a12e11 commit 5226231

File tree

2 files changed

+291
-5
lines changed

2 files changed

+291
-5
lines changed

libs/community/langchain_community/utilities/sql_database.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,9 @@ def table_info(self) -> str:
316316
"""Information about all tables in the database."""
317317
return self.get_table_info()
318318

319-
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
319+
def get_table_info(
320+
self, table_names: Optional[List[str]] = None, get_col_comments: bool = False
321+
) -> str:
320322
"""Get information about specified tables.
321323
322324
Follows best practices as specified in: Rajkumar et al, 2022
@@ -356,14 +358,39 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
356358
tables.append(self._custom_table_info[table.name])
357359
continue
358360

359-
# Ignore JSON datatyped columns
360-
for k, v in table.columns.items(): # AttributeError: items in sqlalchemy v1
361-
if type(v.type) is NullType:
362-
table._columns.remove(v)
361+
# Ignore JSON datatyped columns - SQLAlchemy v1.x compatibility
362+
try:
363+
# For SQLAlchemy v2.x
364+
for k, v in table.columns.items():
365+
if type(v.type) is NullType:
366+
table._columns.remove(v)
367+
except AttributeError:
368+
# For SQLAlchemy v1.x
369+
for k, v in dict(table.columns).items():
370+
if type(v.type) is NullType:
371+
table._columns.remove(v)
363372

364373
# add create table command
365374
create_table = str(CreateTable(table).compile(self._engine))
366375
table_info = f"{create_table.rstrip()}"
376+
377+
# Add column comments as dictionary
378+
if get_col_comments:
379+
try:
380+
column_comments_dict = {}
381+
for column in table.columns:
382+
if column.comment:
383+
column_comments_dict[column.name] = column.comment
384+
385+
if column_comments_dict:
386+
table_info += (
387+
f"\n\n/*\nColumn Comments: {column_comments_dict}\n*/"
388+
)
389+
except Exception:
390+
raise ValueError(
391+
"Column comments are available on PostgreSQL, MySQL, Oracle"
392+
)
393+
367394
has_extra_info = (
368395
self._indexes_in_table_info or self._sample_rows_in_table_info
369396
)
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import unittest
2+
from typing import Dict, Optional
3+
from unittest.mock import MagicMock, patch
4+
5+
from sqlalchemy import Column, Integer, MetaData, String, Table
6+
7+
from langchain_community.utilities.sql_database import SQLDatabase
8+
9+
10+
class TestSQLDatabaseComments(unittest.TestCase):
11+
"""Test class for column comment functionality in SQLDatabase"""
12+
13+
def setUp(self) -> None:
14+
"""Setup before each test"""
15+
# Mock Engine
16+
self.mock_engine = MagicMock()
17+
self.mock_engine.dialect.name = "postgresql" # Default to PostgreSQL
18+
19+
# Mock inspector and start patch *before* SQLDatabase initialization
20+
self.mock_inspector = MagicMock()
21+
# Mock table name list and other inspector methods called during init
22+
self.mock_inspector.get_table_names.return_value = ["test_table"]
23+
self.mock_inspector.get_view_names.return_value = []
24+
self.mock_inspector.get_indexes.return_value = []
25+
# Mock get_columns to return something reasonable for reflection
26+
self.mock_inspector.get_columns.return_value = [
27+
{
28+
"name": "id",
29+
"type": Integer(),
30+
"nullable": False,
31+
"default": None,
32+
"autoincrement": "auto",
33+
"comment": None,
34+
},
35+
{
36+
"name": "name",
37+
"type": String(100),
38+
"nullable": True,
39+
"default": None,
40+
"autoincrement": "auto",
41+
"comment": None,
42+
},
43+
{
44+
"name": "age",
45+
"type": Integer(),
46+
"nullable": True,
47+
"default": None,
48+
"autoincrement": "auto",
49+
"comment": None,
50+
},
51+
]
52+
# Mock get_pk_constraint for reflection
53+
self.mock_inspector.get_pk_constraint.return_value = {
54+
"constrained_columns": ["id"],
55+
"name": None,
56+
}
57+
# Mock get_foreign_keys for reflection
58+
self.mock_inspector.get_foreign_keys.return_value = []
59+
60+
# Patch sqlalchemy.inspect to return our mock inspector
61+
self.patch_inspector = patch(
62+
"langchain_community.utilities.sql_database.inspect",
63+
return_value=self.mock_inspector,
64+
)
65+
# Start the patch *before* creating the SQLDatabase instance
66+
self.mock_inspect = self.patch_inspector.start()
67+
68+
# Mock metadata
69+
self.metadata = MetaData()
70+
71+
# Create test database object *after* patching inspect
72+
try:
73+
self.db = SQLDatabase(
74+
engine=self.mock_engine,
75+
metadata=self.metadata,
76+
lazy_table_reflection=True,
77+
)
78+
except Exception as e:
79+
self.fail(f"Unexpected exception during SQLDatabase init: {e}")
80+
81+
def tearDown(self) -> None:
82+
"""Cleanup after each test"""
83+
self.patch_inspector.stop()
84+
85+
def setup_mock_table_with_comments(
86+
self, dialect: str, comments: Optional[Dict[str, str]] = None
87+
) -> Table:
88+
"""Setup a mock table with comments
89+
90+
Args:
91+
dialect (str): Database dialect to test (postgresql, mysql, oracle)
92+
comments (dict, optional): Column comments. Uses default comments if None
93+
94+
Returns:
95+
Table: The created mock table
96+
"""
97+
# Default comments
98+
if comments is None:
99+
comments = {
100+
"id": "Primary key",
101+
"name": "Name of the person",
102+
"age": "Age of the person",
103+
}
104+
105+
# Set engine dialect
106+
self.mock_engine.dialect.name = dialect
107+
108+
# Clear existing metadata if necessary, or use a fresh MetaData object
109+
self.metadata.clear()
110+
111+
# Create test table
112+
test_table = Table(
113+
"test_table",
114+
self.metadata,
115+
Column("id", Integer, primary_key=True, comment=comments.get("id")),
116+
Column("name", String(100), comment=comments.get("name")),
117+
Column("age", Integer, comment=comments.get("age")),
118+
)
119+
120+
# Mock reflection to return the columns with comments
121+
# This is crucial because lazy reflection will call inspect later
122+
self.mock_inspector.get_columns.return_value = [
123+
{
124+
"name": "id",
125+
"type": Integer(),
126+
"nullable": False,
127+
"default": None,
128+
"autoincrement": "auto",
129+
"comment": comments.get("id"),
130+
},
131+
{
132+
"name": "name",
133+
"type": String(100),
134+
"nullable": True,
135+
"default": None,
136+
"autoincrement": "auto",
137+
"comment": comments.get("name"),
138+
},
139+
{
140+
"name": "age",
141+
"type": Integer(),
142+
"nullable": True,
143+
"default": None,
144+
"autoincrement": "auto",
145+
"comment": comments.get("age"),
146+
},
147+
]
148+
self.mock_inspector.get_table_names.return_value = [
149+
"test_table"
150+
] # Ensure table is discoverable
151+
152+
# No need to mock CreateTable here, let the actual code call it.
153+
# We will patch it during the get_table_info call in the tests.
154+
155+
# No need to manually add table to metadata, reflection handles it
156+
# self.metadata._add_table("test_table", None, test_table)
157+
158+
return test_table
159+
160+
def _run_test_with_mocked_createtable(self, dialect: str) -> None:
161+
"""Helper function to run comment tests with CreateTable mocked."""
162+
self.setup_mock_table_with_comments(dialect)
163+
164+
# Define the expected CREATE TABLE string
165+
expected_create_table_sql = (
166+
"CREATE TABLE test_table (\n\tid INTEGER NOT NULL, "
167+
"\n\tname VARCHAR(100), \n\tage INTEGER, \n\tPRIMARY KEY (id)\n)"
168+
)
169+
170+
# Patch CreateTable specifically for the get_table_info call
171+
with patch(
172+
"langchain_community.utilities.sql_database.CreateTable"
173+
) as MockCreateTable:
174+
# Mock the compile method to return a specific string
175+
mock_compiler = MockCreateTable.return_value.compile
176+
mock_compiler.return_value = expected_create_table_sql
177+
178+
# Call get_table_info with get_col_comments=True
179+
table_info = self.db.get_table_info(get_col_comments=True)
180+
181+
# Verify CREATE TABLE statement (using the mocked value)
182+
self.assertIn(expected_create_table_sql.strip(), table_info)
183+
184+
# Verify comments are included in table info in the correct format
185+
self.assertIn("/*\nColumn Comments:", table_info)
186+
self.assertIn("'id': 'Primary key'", table_info)
187+
self.assertIn("'name': 'Name of the person'", table_info)
188+
self.assertIn("'age': 'Age of the person'", table_info)
189+
self.assertIn("*/", table_info)
190+
191+
def test_postgres_get_col_comments(self) -> None:
192+
"""Test retrieving column comments from PostgreSQL"""
193+
self._run_test_with_mocked_createtable("postgresql")
194+
195+
def test_mysql_get_col_comments(self) -> None:
196+
"""Test retrieving column comments from MySQL"""
197+
self._run_test_with_mocked_createtable("mysql")
198+
199+
def test_oracle_get_col_comments(self) -> None:
200+
"""Test retrieving column comments from Oracle"""
201+
self._run_test_with_mocked_createtable("oracle")
202+
203+
def test_sqlite_no_comments(self) -> None:
204+
"""Test that SQLite does not add a comment block when comments are missing."""
205+
# Setup SQLite table (comments will be ignored by SQLAlchemy for SQLite)
206+
self.setup_mock_table_with_comments("sqlite", comments={})
207+
# Mock reflection to return columns *without* comments
208+
self.mock_inspector.get_columns.return_value = [
209+
{
210+
"name": "id",
211+
"type": Integer(),
212+
"nullable": False,
213+
"default": None,
214+
"autoincrement": "auto",
215+
"comment": None,
216+
},
217+
{
218+
"name": "name",
219+
"type": String(100),
220+
"nullable": True,
221+
"default": None,
222+
"autoincrement": "auto",
223+
"comment": None,
224+
},
225+
{
226+
"name": "age",
227+
"type": Integer(),
228+
"nullable": True,
229+
"default": None,
230+
"autoincrement": "auto",
231+
"comment": None,
232+
},
233+
]
234+
235+
# Define the expected CREATE TABLE string
236+
expected_create_table_sql = (
237+
"CREATE TABLE test_table (\n\tid INTEGER NOT NULL, "
238+
"\n\tname VARCHAR(100), \n\tage INTEGER, \n\tPRIMARY KEY (id)\n)"
239+
)
240+
241+
# Patch CreateTable specifically for the get_table_info call
242+
with patch(
243+
"langchain_community.utilities.sql_database.CreateTable"
244+
) as MockCreateTable:
245+
mock_compiler = MockCreateTable.return_value.compile
246+
mock_compiler.return_value = expected_create_table_sql
247+
248+
# Call get_table_info with get_col_comments=True
249+
# Even if True, SQLite won't have comments to add.
250+
table_info = self.db.get_table_info(get_col_comments=True)
251+
252+
# Verify CREATE TABLE statement
253+
self.assertIn(expected_create_table_sql.strip(), table_info)
254+
# Verify comments block is NOT included
255+
self.assertNotIn("Column Comments:", table_info)
256+
257+
258+
if __name__ == "__main__":
259+
unittest.main()

0 commit comments

Comments
 (0)