@@ -35,6 +35,9 @@ class SQLDatabaseChain(Chain, BaseModel):
35
35
input_key : str = "query" #: :meta private:
36
36
output_key : str = "result" #: :meta private:
37
37
return_intermediate_steps : bool = False
38
+ """Whether or not to return the intermediate steps along with the final answer."""
39
+ return_direct : bool = False
40
+ """Whether or not to return the result of querying the SQL table directly."""
38
41
39
42
class Config :
40
43
"""Configuration for this pydantic object."""
@@ -83,11 +86,17 @@ def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
83
86
intermediate_steps .append (result )
84
87
self .callback_manager .on_text ("\n SQLResult: " , verbose = self .verbose )
85
88
self .callback_manager .on_text (result , color = "yellow" , verbose = self .verbose )
86
- self .callback_manager .on_text ("\n Answer:" , verbose = self .verbose )
87
- input_text += f"{ sql_cmd } \n SQLResult: { result } \n Answer:"
88
- llm_inputs ["input" ] = input_text
89
- final_result = llm_chain .predict (** llm_inputs )
90
- self .callback_manager .on_text (final_result , color = "green" , verbose = self .verbose )
89
+ # If return direct, we just set the final result equal to the sql query
90
+ if self .return_direct :
91
+ final_result = result
92
+ else :
93
+ self .callback_manager .on_text ("\n Answer:" , verbose = self .verbose )
94
+ input_text += f"{ sql_cmd } \n SQLResult: { result } \n Answer:"
95
+ llm_inputs ["input" ] = input_text
96
+ final_result = llm_chain .predict (** llm_inputs )
97
+ self .callback_manager .on_text (
98
+ final_result , color = "green" , verbose = self .verbose
99
+ )
91
100
chain_result : Dict [str , Any ] = {self .output_key : final_result }
92
101
if self .return_intermediate_steps :
93
102
chain_result ["intermediate_steps" ] = intermediate_steps
0 commit comments