Skip to content

Commit 94ae126

Browse files
authored
return sql intermediate steps (#792)
1 parent ae5695a commit 94ae126

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

docs/modules/chains/examples/sqlite.ipynb

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
},
6969
{
7070
"cell_type": "code",
71-
"execution_count": 3,
71+
"execution_count": 4,
7272
"id": "15ff81df",
7373
"metadata": {
7474
"pycharm": {
@@ -96,7 +96,7 @@
9696
"' There are 9 employees.'"
9797
]
9898
},
99-
"execution_count": 3,
99+
"execution_count": 4,
100100
"metadata": {},
101101
"output_type": "execute_result"
102102
}
@@ -188,6 +188,62 @@
188188
"db_chain.run(\"How many employees are there in the foobar table?\")"
189189
]
190190
},
191+
{
192+
"cell_type": "markdown",
193+
"id": "88d8b969",
194+
"metadata": {},
195+
"source": [
196+
"## Return Intermediate Steps\n",
197+
"\n",
198+
"You can also return the intermediate steps of the SQLDatabaseChain. This allows you to access the SQL statement that was generated, as well as the result of running that against the SQL Database."
199+
]
200+
},
201+
{
202+
"cell_type": "code",
203+
"execution_count": 8,
204+
"id": "38559487",
205+
"metadata": {},
206+
"outputs": [],
207+
"source": [
208+
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)"
209+
]
210+
},
211+
{
212+
"cell_type": "code",
213+
"execution_count": 10,
214+
"id": "78b6af4d",
215+
"metadata": {},
216+
"outputs": [
217+
{
218+
"name": "stdout",
219+
"output_type": "stream",
220+
"text": [
221+
"\n",
222+
"\n",
223+
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
224+
"How many employees are there in the foobar table? \n",
225+
"SQLQuery:\u001b[32;1m\u001b[1;3m SELECT COUNT(*) FROM Employee;\u001b[0m\n",
226+
"SQLResult: \u001b[33;1m\u001b[1;3m[(9,)]\u001b[0m\n",
227+
"Answer:\u001b[32;1m\u001b[1;3m There are 9 employees in the foobar table.\u001b[0m\n",
228+
"\u001b[1m> Finished chain.\u001b[0m\n"
229+
]
230+
},
231+
{
232+
"data": {
233+
"text/plain": [
234+
"[' SELECT COUNT(*) FROM Employee;', '[(9,)]']"
235+
]
236+
},
237+
"execution_count": 10,
238+
"metadata": {},
239+
"output_type": "execute_result"
240+
}
241+
],
242+
"source": [
243+
"result = db_chain(\"How many employees are there in the foobar table?\")\n",
244+
"result[\"intermediate_steps\"]"
245+
]
246+
},
191247
{
192248
"cell_type": "markdown",
193249
"id": "b408f800",
@@ -405,7 +461,7 @@
405461
"name": "python",
406462
"nbconvert_exporter": "python",
407463
"pygments_lexer": "ipython3",
408-
"version": "3.8.16"
464+
"version": "3.10.9"
409465
}
410466
},
411467
"nbformat": 4,

langchain/chains/sql_database/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class SQLDatabaseChain(Chain, BaseModel):
3434
"""Number of results to return from the query"""
3535
input_key: str = "query" #: :meta private:
3636
output_key: str = "result" #: :meta private:
37+
return_intermediate_steps: bool = False
3738

3839
class Config:
3940
"""Configuration for this pydantic object."""
@@ -55,9 +56,12 @@ def output_keys(self) -> List[str]:
5556
5657
:meta private:
5758
"""
58-
return [self.output_key]
59+
if not self.return_intermediate_steps:
60+
return [self.output_key]
61+
else:
62+
return [self.output_key, "intermediate_steps"]
5963

60-
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
64+
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
6165
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
6266
input_text = f"{inputs[self.input_key]} \nSQLQuery:"
6367
self.callback_manager.on_text(input_text, verbose=self.verbose)
@@ -71,18 +75,23 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
7175
"table_info": table_info,
7276
"stop": ["\nSQLResult:"],
7377
}
74-
78+
intermediate_steps = []
7579
sql_cmd = llm_chain.predict(**llm_inputs)
80+
intermediate_steps.append(sql_cmd)
7681
self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
7782
result = self.database.run(sql_cmd)
83+
intermediate_steps.append(result)
7884
self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
7985
self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
8086
self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
8187
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
8288
llm_inputs["input"] = input_text
8389
final_result = llm_chain.predict(**llm_inputs)
8490
self.callback_manager.on_text(final_result, color="green", verbose=self.verbose)
85-
return {self.output_key: final_result}
91+
chain_result: Dict[str, Any] = {self.output_key: final_result}
92+
if self.return_intermediate_steps:
93+
chain_result["intermediate_steps"] = intermediate_steps
94+
return chain_result
8695

8796
@property
8897
def _chain_type(self) -> str:

0 commit comments

Comments
 (0)