Skip to content

Commit 248c297

Browse files
hwchase17fpingham
andauthored
Sample row in table info for SQLDatabase (#769) (#782)
The agents usually benefit from understanding what the data looks like to be able to filter effectively. Sending just one row in the table info allows the agent to understand the data before querying and get better results. --------- Co-authored-by: Francisco Ingham <> --------- Co-authored-by: Francisco Ingham <[email protected]>
1 parent 213c2e3 commit 248c297

File tree

3 files changed

+107
-10
lines changed

3 files changed

+107
-10
lines changed

docs/modules/chains/examples/sqlite.ipynb

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
},
5959
{
6060
"cell_type": "code",
61-
"execution_count": null,
61+
"execution_count": 3,
6262
"id": "a8fc8f23",
6363
"metadata": {},
6464
"outputs": [],
@@ -242,6 +242,74 @@
242242
"db_chain.run(\"What are some example tracks by composer Johann Sebastian Bach?\")"
243243
]
244244
},
245+
{
246+
"cell_type": "markdown",
247+
"id": "bcc5e936",
248+
"metadata": {},
249+
"source": [
250+
"## Adding first row of each table\n",
251+
"Sometimes, the format of the data is not obvious and it is optimal to include the first row of the table in the prompt to allow the LLM to understand the data before providing a final query. Here we will use this feature to let the LLM know that artists are saved with their full names."
252+
]
253+
},
254+
{
255+
"cell_type": "code",
256+
"execution_count": 11,
257+
"id": "9a22ee47",
258+
"metadata": {},
259+
"outputs": [],
260+
"source": [
261+
"db = SQLDatabase.from_uri(\n",
262+
" \"sqlite:///../../../../notebooks/Chinook.db\", \n",
263+
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
264+
" sample_row_in_table_info=True)"
265+
]
266+
},
267+
{
268+
"cell_type": "code",
269+
"execution_count": 12,
270+
"id": "bcb7a489",
271+
"metadata": {},
272+
"outputs": [],
273+
"source": [
274+
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": 13,
280+
"id": "81e05d82",
281+
"metadata": {},
282+
"outputs": [
283+
{
284+
"name": "stdout",
285+
"output_type": "stream",
286+
"text": [
287+
"\n",
288+
"\n",
289+
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
290+
"What are some example tracks by Bach? \n",
291+
"SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n",
292+
"\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n",
293+
"SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n",
294+
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n",
295+
"\u001b[1m> Finished chain.\u001b[0m\n"
296+
]
297+
},
298+
{
299+
"data": {
300+
"text/plain": [
301+
"' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'"
302+
]
303+
},
304+
"execution_count": 13,
305+
"metadata": {},
306+
"output_type": "execute_result"
307+
}
308+
],
309+
"source": [
310+
"db_chain.run(\"What are some example tracks by Bach?\")"
311+
]
312+
},
245313
{
246314
"cell_type": "markdown",
247315
"id": "c12ae15a",
@@ -319,14 +387,6 @@
319387
"source": [
320388
"chain.run(\"How many employees are also customers?\")"
321389
]
322-
},
323-
{
324-
"cell_type": "code",
325-
"execution_count": null,
326-
"id": "b2998b03",
327-
"metadata": {},
328-
"outputs": [],
329-
"source": []
330390
}
331391
],
332392
"metadata": {
@@ -345,7 +405,7 @@
345405
"name": "python",
346406
"nbconvert_exporter": "python",
347407
"pygments_lexer": "ipython3",
348-
"version": "3.10.9"
408+
"version": "3.8.16"
349409
}
350410
},
351411
"nbformat": 4,

langchain/sql_database.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
schema: Optional[str] = None,
1717
ignore_tables: Optional[List[str]] = None,
1818
include_tables: Optional[List[str]] = None,
19+
sample_row_in_table_info: bool = False,
1920
):
2021
"""Create engine from database URI."""
2122
self._engine = engine
@@ -39,6 +40,7 @@ def __init__(
3940
raise ValueError(
4041
f"ignore_tables {missing_tables} not found in database"
4142
)
43+
self._sample_row_in_table_info = sample_row_in_table_info
4244

4345
@classmethod
4446
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
@@ -69,14 +71,28 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
6971
if missing_tables:
7072
raise ValueError(f"table_names {missing_tables} not found in database")
7173
all_table_names = table_names
74+
7275
template = "Table '{table_name}' has columns: {columns}."
76+
7377
tables = []
7478
for table_name in all_table_names:
79+
7580
columns = []
7681
for column in self._inspector.get_columns(table_name, schema=self._schema):
7782
columns.append(f"{column['name']} ({str(column['type'])})")
7883
column_str = ", ".join(columns)
7984
table_str = template.format(table_name=table_name, columns=column_str)
85+
86+
if self._sample_row_in_table_info:
87+
row_template = (
88+
" Here is an example row for this table"
89+
" (long strings are truncated): {sample_row}."
90+
)
91+
sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1")
92+
if len(eval(sample_row)) > 0:
93+
sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]])
94+
table_str += row_template.format(sample_row=sample_row)
95+
8096
tables.append(table_str)
8197
return "\n".join(tables)
8298

tests/unit_tests/test_sql_database.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,27 @@ def test_table_info() -> None:
3535
assert sorted(output.split("\n")) == sorted(expected_output)
3636

3737

38+
def test_table_info_w_sample_row() -> None:
39+
"""Test that table info is constructed properly."""
40+
engine = create_engine("sqlite:///:memory:")
41+
metadata_obj.create_all(engine)
42+
stmt = insert(user).values(user_id=13, user_name="Harrison")
43+
with engine.begin() as conn:
44+
conn.execute(stmt)
45+
46+
db = SQLDatabase(engine, sample_row_in_table_info=True)
47+
48+
output = db.table_info
49+
expected_output = (
50+
"Table 'company' has columns: company_id (INTEGER), "
51+
"company_location (VARCHAR).\n"
52+
"Table 'user' has columns: user_id (INTEGER), "
53+
"user_name (VARCHAR(16)). Here is an example row "
54+
"for this table (long strings are truncated): 13 Harrison."
55+
)
56+
assert sorted(output.split("\n")) == sorted(expected_output.split("\n"))
57+
58+
3859
def test_sql_database_run() -> None:
3960
"""Test that commands can be run successfully and returned in correct format."""
4061
engine = create_engine("sqlite:///:memory:")

0 commit comments

Comments
 (0)