Skip to content

Commit dc65fa9

Browse files
committed
Update parameters
1 parent cc3d8cc commit dc65fa9

File tree

3 files changed

+55
-30
lines changed

3 files changed

+55
-30
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,57 +68,48 @@ def get_all_agents(self):
6868
# Get current datetime for the Query Rewrite Agent
6969
current_datetime = datetime.now()
7070

71-
QUERY_REWRITE_AGENT = LLMAgentCreator.create(
71+
self.query_rewrite_agent = LLMAgentCreator.create(
7272
"query_rewrite_agent", current_datetime=current_datetime
7373
)
7474

75-
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
75+
self.sql_query_generation_agent = LLMAgentCreator.create(
7676
"sql_query_generation_agent",
7777
target_engine=self.target_engine,
7878
engine_specific_rules=self.engine_specific_rules,
7979
**self.kwargs,
8080
)
8181

82-
SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
83-
target_engine=self.target_engine,
84-
engine_specific_rules=self.engine_specific_rules,
85-
**self.kwargs,
86-
)
82+
self.sql_schema_selection_agent = SqlSchemaSelectionAgent()
8783

88-
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
84+
self.sql_query_correction_agent = LLMAgentCreator.create(
8985
"sql_query_correction_agent",
9086
target_engine=self.target_engine,
9187
engine_specific_rules=self.engine_specific_rules,
9288
**self.kwargs,
9389
)
9490

95-
SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
91+
self.sql_disambiguation_agent = LLMAgentCreator.create(
9692
"sql_disambiguation_agent",
9793
target_engine=self.target_engine,
9894
engine_specific_rules=self.engine_specific_rules,
9995
**self.kwargs,
10096
)
10197

102-
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
103-
"question_decomposition_agent"
104-
)
105-
10698
# Auto-responding UserProxyAgent
107-
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")
99+
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
108100

109101
agents = [
110-
USER_PROXY,
111-
QUERY_REWRITE_AGENT,
112-
SQL_QUERY_GENERATION_AGENT,
113-
SQL_SCHEMA_SELECTION_AGENT,
114-
SQL_QUERY_CORRECTION_AGENT,
115-
QUESTION_DECOMPOSITION_AGENT,
116-
SQL_DISAMBIGUATION_AGENT,
102+
self.user_proxy,
103+
self.query_rewrite_agent,
104+
self.sql_query_generation_agent,
105+
self.sql_schema_selection_agent,
106+
self.sql_query_correction_agent,
107+
self.sql_disambiguation_agent,
117108
]
118109

119110
if self.use_query_cache:
120-
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
121-
agents.append(SQL_QUERY_CACHE_AGENT)
111+
self.query_cache_agent = SqlQueryCacheAgent(**self.kwargs)
112+
agents.append(self.query_cache_agent)
122113

123114
return agents
124115

@@ -195,13 +186,30 @@ def agentic_flow(self):
195186
)
196187
return flow
197188

198-
async def process_question(self, task: str, chat_history: list[str] = None):
199-
"""Process the complete question through the unified system."""
189+
async def process_question(
190+
self, task: str, chat_history: list[str] = None, parameters: dict = None
191+
):
192+
"""Process the complete question through the unified system.
193+
194+
Args:
195+
----
196+
task (str): The user question to process.
197+
chat_history (list[str], optional): The chat history. Defaults to None.
198+
parameters (dict, optional): The parameters to pass to the agents. Defaults to None.
199+
200+
Returns:
201+
-------
202+
dict: The response from the system.
203+
"""
200204

201205
logging.info("Processing question: %s", task)
202206
logging.info("Chat history: %s", chat_history)
203207

204-
agent_input = {"user_question": task, "chat_history": {}}
208+
agent_input = {
209+
"user_question": task,
210+
"chat_history": {},
211+
"parameters": parameters,
212+
}
205213

206214
if chat_history is not None:
207215
# Update input

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313

1414
class SqlQueryCacheAgent(BaseChatAgent):
15-
def __init__(self, name: str = "sql_query_cache_agent"):
15+
def __init__(self):
1616
super().__init__(
17-
name,
17+
"sql_query_cache_agent",
1818
"An agent that fetches the queries from the cache based on the user question.",
1919
)
2020

@@ -39,9 +39,11 @@ async def on_messages_stream(
3939
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4040
) -> AsyncGenerator[AgentMessage | Response, None]:
4141
# Get the decomposed questions from the query_rewrite_agent
42+
parameter_input = messages[0].content
4243
last_response = messages[-1].content
4344
try:
4445
user_questions = json.loads(last_response)
46+
user_parameters = json.loads(parameter_input)["parameters"]
4547
logging.info(f"Processing questions: {user_questions}")
4648

4749
# Initialize results dictionary
@@ -55,7 +57,7 @@ async def on_messages_stream(
5557
# Fetch the queries from the cache based on the question
5658
logging.info(f"Fetching queries from cache for question: {question}")
5759
cached_query = await self.sql_connector.fetch_queries_from_cache(
58-
question
60+
question, parameters=user_parameters
5961
)
6062

6163
# If any question has pre-run results, set the flag

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sqlglot
99
from abc import ABC, abstractmethod
1010
from datetime import datetime
11+
from jinja2 import Template
1112

1213

1314
class SqlConnector(ABC):
@@ -118,7 +119,9 @@ async def query_validation(
118119
logging.info("SQL Query is valid.")
119120
return True
120121

121-
async def fetch_queries_from_cache(self, question: str) -> str:
122+
async def fetch_queries_from_cache(
123+
self, question: str, parameters: dict = None
124+
) -> str:
122125
"""Fetch the queries from the cache based on the question.
123126
124127
Args:
@@ -129,6 +132,10 @@ async def fetch_queries_from_cache(self, question: str) -> str:
129132
-------
130133
str: The formatted string of the queries fetched from the cache. This is injected into the prompt.
131134
"""
135+
136+
if parameters is None:
137+
parameters = {}
138+
132139
cached_schemas = await self.ai_search_connector.run_ai_search_query(
133140
question,
134141
["QuestionEmbedding"],
@@ -146,6 +153,14 @@ async def fetch_queries_from_cache(self, question: str) -> str:
146153
"cached_questions_and_schemas": None,
147154
}
148155

156+
# loop through all sql queries and populate the template in place
157+
for schema in cached_schemas:
158+
sql_queries = schema["SqlQueryDecomposition"]
159+
for sql_query in sql_queries:
160+
sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render(
161+
**parameters
162+
)
163+
149164
logging.info("Cached schemas: %s", cached_schemas)
150165
if self.pre_run_query_cache and len(cached_schemas) > 0:
151166
# check the score

0 commit comments

Comments
 (0)