Skip to content

Commit 6be5f4e

Browse files
hwchase17bborn
andauthored
Harrison/sql db chain (#641)
Co-authored-by: Bruno Bornsztein <[email protected]>
1 parent b550f57 commit 6be5f4e

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

docs/modules/chains/examples/sqlite.ipynb

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@
266266
"metadata": {},
267267
"outputs": [],
268268
"source": [
269-
"from langchain.chains.sql_database.base import SQLDatabaseSequentialChain"
269+
"from langchain.chains import SQLDatabaseSequentialChain"
270270
]
271271
},
272272
{
@@ -293,14 +293,22 @@
293293
"\n",
294294
"\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n",
295295
"Table names to use:\n",
296-
"\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n",
296+
"\u001b[33;1m\u001b[1;3m['Customer', 'Employee']\u001b[0m\n",
297+
"\n",
298+
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
299+
"How many employees are also customers? \n",
300+
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Customer c INNER JOIN Employee e ON c.SupportRepId = e.EmployeeId;\u001b[0m\n",
301+
"SQLResult: \u001b[33;1m\u001b[1;3m[(59,)]\u001b[0m\n",
302+
"Answer:\u001b[32;1m\u001b[1;3m There are 59 employees who are also customers.\u001b[0m\n",
303+
"\u001b[1m> Finished chain.\u001b[0m\n",
304+
"\n",
297305
"\u001b[1m> Finished chain.\u001b[0m\n"
298306
]
299307
},
300308
{
301309
"data": {
302310
"text/plain": [
303-
"' 0 employees are also customers.'"
311+
"' There are 59 employees who are also customers.'"
304312
]
305313
},
306314
"execution_count": 5,

langchain/chains/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
1313
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
1414
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
15-
from langchain.chains.sql_database.base import SQLDatabaseChain
15+
from langchain.chains.sql_database.base import (
16+
SQLDatabaseChain,
17+
SQLDatabaseSequentialChain,
18+
)
1619
from langchain.chains.transform import TransformChain
1720
from langchain.chains.vector_db_qa.base import VectorDBQA
1821

@@ -35,4 +38,5 @@
3538
"TransformChain",
3639
"MapReduceChain",
3740
"OpenAIModerationChain",
41+
"SQLDatabaseSequentialChain",
3842
]

langchain/chains/sql_database/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def from_llm(
109109
**kwargs: Any,
110110
) -> SQLDatabaseSequentialChain:
111111
"""Load the necessary chains."""
112-
sql_chain = SQLDatabaseChain(llm=llm, database=database, prompt=query_prompt)
112+
sql_chain = SQLDatabaseChain(
113+
llm=llm, database=database, prompt=query_prompt, **kwargs
114+
)
113115
decider_chain = LLMChain(
114116
llm=llm, prompt=decider_prompt, output_key="table_names"
115117
)

0 commit comments

Comments
 (0)