Skip to content

Commit 1c71fad

Browse files
authored
more complex sql chain (#619)
add a more complex sql chain that first subsets the necessary tables
1 parent 49b3d6c commit 1c71fad

File tree

6 files changed

+184
-8
lines changed

6 files changed

+184
-8
lines changed

docs/modules/chains/examples/sqlite.ipynb

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,82 @@
179179
"db_chain.run(\"How many employees are there in the foobar table?\")"
180180
]
181181
},
182+
{
183+
"cell_type": "markdown",
184+
"id": "c12ae15a",
185+
"metadata": {},
186+
"source": [
187+
"## SQLDatabaseSequentialChain\n",
188+
"\n",
189+
"Chain for querying SQL database that is a sequential chain.\n",
190+
"\n",
191+
"The chain is as follows:\n",
192+
"\n",
193+
" 1. Based on the query, determine which tables to use.\n",
194+
" 2. Based on those tables, call the normal SQL database chain.\n",
195+
"\n",
196+
"This is useful in cases where the number of tables in the database is large."
197+
]
198+
},
182199
{
183200
"cell_type": "code",
184-
"execution_count": null,
201+
"execution_count": 3,
185202
"id": "e59a4740",
186203
"metadata": {},
187204
"outputs": [],
205+
"source": [
206+
"from langchain.chains.sql_database.base import SQLDatabaseSequentialChain"
207+
]
208+
},
209+
{
210+
"cell_type": "code",
211+
"execution_count": 4,
212+
"id": "58bb49b6",
213+
"metadata": {},
214+
"outputs": [],
215+
"source": [
216+
"chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)"
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": 5,
222+
"id": "95017b1a",
223+
"metadata": {},
224+
"outputs": [
225+
{
226+
"name": "stdout",
227+
"output_type": "stream",
228+
"text": [
229+
"\n",
230+
"\n",
231+
"\u001b[1m> Entering new SQLDatabaseSequentialChain chain...\u001b[0m\n",
232+
"Table names to use:\n",
233+
"\u001b[33;1m\u001b[1;3m['Employee', 'Customer']\u001b[0m\n",
234+
"\u001b[1m> Finished chain.\u001b[0m\n"
235+
]
236+
},
237+
{
238+
"data": {
239+
"text/plain": [
240+
"' 0 employees are also customers.'"
241+
]
242+
},
243+
"execution_count": 5,
244+
"metadata": {},
245+
"output_type": "execute_result"
246+
}
247+
],
248+
"source": [
249+
"chain.run(\"How many employees are also customers?\")"
250+
]
251+
},
252+
{
253+
"cell_type": "code",
254+
"execution_count": null,
255+
"id": "b2998b03",
256+
"metadata": {},
257+
"outputs": [],
188258
"source": []
189259
}
190260
],

langchain/chains/sequential.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def validate_chains(cls, values: Dict) -> Dict:
4747
for chain in chains:
4848
missing_vars = set(chain.input_keys).difference(known_variables)
4949
if missing_vars:
50-
raise ValueError(f"Missing required input keys: {missing_vars}")
50+
raise ValueError(
51+
f"Missing required input keys: {missing_vars}, "
52+
f"only had {known_variables}"
53+
)
5154
overlapping_keys = known_variables.intersection(chain.output_keys)
5255
if overlapping_keys:
5356
raise ValueError(

langchain/chains/sql_database/base.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Chain for interacting with SQL Database."""
2-
from typing import Dict, List
2+
from __future__ import annotations
3+
4+
from typing import Any, Dict, List
35

46
from pydantic import BaseModel, Extra
57

68
from langchain.chains.base import Chain
79
from langchain.chains.llm import LLMChain
8-
from langchain.chains.sql_database.prompt import PROMPT
10+
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT
911
from langchain.llms.base import BaseLLM
1012
from langchain.prompts.base import BasePromptTemplate
1113
from langchain.sql_database import SQLDatabase
@@ -53,15 +55,18 @@ def output_keys(self) -> List[str]:
5355
"""
5456
return [self.output_key]
5557

56-
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
58+
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
5759
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
5860
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
5961
if self.verbose:
6062
self.callback_manager.on_text(input_text)
63+
# If not present, then defaults to None which is all tables.
64+
table_names_to_use = inputs.get("table_names_to_use")
65+
table_info = self.database.get_table_info(table_names=table_names_to_use)
6166
llm_inputs = {
6267
"input": input_text,
6368
"dialect": self.database.dialect,
64-
"table_info": self.database.table_info,
69+
"table_info": table_info,
6570
"stop": ["\nSQLResult:"],
6671
}
6772
sql_cmd = llm_chain.predict(**llm_inputs)
@@ -78,3 +83,68 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
7883
if self.verbose:
7984
self.callback_manager.on_text(final_result, color="green")
8085
return {self.output_key: final_result}
86+
87+
88+
class SQLDatabaseSequentialChain(Chain, BaseModel):
89+
"""Chain for querying SQL database that is a sequential chain.
90+
91+
The chain is as follows:
92+
1. Based on the query, determine which tables to use.
93+
2. Based on those tables, call the normal SQL database chain.
94+
95+
This is useful in cases where the number of tables in the database is large.
96+
"""
97+
98+
@classmethod
99+
def from_llm(
100+
cls,
101+
llm: BaseLLM,
102+
database: SQLDatabase,
103+
query_prompt: BasePromptTemplate = PROMPT,
104+
decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
105+
**kwargs: Any,
106+
) -> SQLDatabaseSequentialChain:
107+
"""Load the necessary chains."""
108+
sql_chain = SQLDatabaseChain(llm=llm, database=database, prompt=query_prompt)
109+
decider_chain = LLMChain(
110+
llm=llm, prompt=decider_prompt, output_key="table_names"
111+
)
112+
return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)
113+
114+
decider_chain: LLMChain
115+
sql_chain: SQLDatabaseChain
116+
input_key: str = "query" #: :meta private:
117+
output_key: str = "result" #: :meta private:
118+
119+
@property
120+
def input_keys(self) -> List[str]:
121+
"""Return the singular input key.
122+
123+
:meta private:
124+
"""
125+
return [self.input_key]
126+
127+
@property
128+
def output_keys(self) -> List[str]:
129+
"""Return the singular output key.
130+
131+
:meta private:
132+
"""
133+
return [self.output_key]
134+
135+
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
136+
_table_names = self.sql_chain.database.get_table_names()
137+
table_names = ", ".join(_table_names)
138+
llm_inputs = {
139+
"query": inputs[self.input_key],
140+
"table_names": table_names,
141+
}
142+
table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs)
143+
if self.verbose:
144+
self.callback_manager.on_text("Table names to use:", end="\n")
145+
self.callback_manager.on_text(str(table_names_to_use), color="yellow")
146+
new_inputs = {
147+
self.sql_chain.input_key: inputs[self.input_key],
148+
"table_names_to_use": table_names_to_use,
149+
}
150+
return self.sql_chain(new_inputs, return_only_outputs=True)

langchain/chains/sql_database/prompt.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# flake8: noqa
2+
from langchain.prompts.base import CommaSeparatedListOutputParser
23
from langchain.prompts.prompt import PromptTemplate
34

45
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
@@ -17,3 +18,16 @@
1718
PROMPT = PromptTemplate(
1819
input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
1920
)
21+
22+
_DECIDER_TEMPLATE = """Given the below input question and list of potential tables, output a comma separated list of the table names that may be neccessary to answer this question.
23+
24+
Question: {query}
25+
26+
Table Names: {table_names}
27+
28+
Relevant Table Names:"""
29+
DECIDER_PROMPT = PromptTemplate(
30+
input_variables=["query", "table_names"],
31+
template=_DECIDER_TEMPLATE,
32+
output_parser=CommaSeparatedListOutputParser(),
33+
)

langchain/prompts/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def parse(self, text: str) -> List[str]:
6464
"""Parse the output of an LLM call."""
6565

6666

67+
class CommaSeparatedListOutputParser(ListOutputParser):
68+
"""Parse out comma separated lists."""
69+
70+
def parse(self, text: str) -> List[str]:
71+
"""Parse the output of an LLM call."""
72+
return text.strip().split(", ")
73+
74+
6775
class RegexParser(BaseOutputParser, BaseModel):
6876
"""Class to parse the output into a dictionary."""
6977

langchain/sql_database.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,28 @@ def dialect(self) -> str:
5050
"""Return string representation of dialect to use."""
5151
return self._engine.dialect.name
5252

53-
def _get_table_names(self) -> Iterable[str]:
53+
def get_table_names(self) -> Iterable[str]:
54+
"""Get names of tables available."""
5455
if self._include_tables:
5556
return self._include_tables
5657
return set(self._all_tables) - set(self._ignore_tables)
5758

5859
@property
5960
def table_info(self) -> str:
6061
"""Information about all tables in the database."""
62+
return self.get_table_info()
63+
64+
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
65+
"""Get information about specified tables."""
66+
all_table_names = self.get_table_names()
67+
if table_names is not None:
68+
missing_tables = set(table_names).difference(all_table_names)
69+
if missing_tables:
70+
raise ValueError(f"table_names {missing_tables} not found in database")
71+
all_table_names = table_names
6172
template = "Table '{table_name}' has columns: {columns}."
6273
tables = []
63-
for table_name in self._get_table_names():
74+
for table_name in all_table_names:
6475
columns = []
6576
for column in self._inspector.get_columns(table_name, schema=self._schema):
6677
columns.append(f"{column['name']} ({str(column['type'])})")

0 commit comments

Comments
 (0)