Skip to content

Commit 1890062

Browse files
Parameter Cache Rendering (#106)
1 parent cc3d8cc commit 1890062

File tree

4 files changed

+86
-80
lines changed

4 files changed

+86
-80
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 3 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@
8585
"metadata": {},
8686
"outputs": [],
8787
"source": [
88-
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.\", use_case=\"Analysing sales data across product categories.\").agentic_flow"
88+
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"\", use_case=\"Analysing sales data across suppliers\")"
8989
]
9090
},
9191
{
@@ -101,56 +101,8 @@
101101
"metadata": {},
102102
"outputs": [],
103103
"source": [
104-
"result = agentic_text_2_sql.run_stream(task=\"What country did we sell the most to in June 2008?\")\n",
105-
"await Console(result)"
106-
]
107-
},
108-
{
109-
"cell_type": "code",
110-
"execution_count": null,
111-
"metadata": {},
112-
"outputs": [],
113-
"source": [
114-
"await agentic_text_2_sql.reset()"
115-
]
116-
},
117-
{
118-
"cell_type": "code",
119-
"execution_count": null,
120-
"metadata": {},
121-
"outputs": [],
122-
"source": [
123-
"result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008 for the mountain bike category?\")\n",
124-
"await Console(result)"
125-
]
126-
},
127-
{
128-
"cell_type": "code",
129-
"execution_count": null,
130-
"metadata": {},
131-
"outputs": [],
132-
"source": [
133-
"await agentic_text_2_sql.reset()"
134-
]
135-
},
136-
{
137-
"cell_type": "code",
138-
"execution_count": null,
139-
"metadata": {},
140-
"outputs": [],
141-
"source": [
142-
"result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008?\")\n",
143-
"await Console(result)"
144-
]
145-
},
146-
{
147-
"cell_type": "code",
148-
"execution_count": null,
149-
"metadata": {},
150-
"outputs": [],
151-
"source": [
152-
"\n",
153-
"await agentic_text_2_sql.reset()"
104+
"result = await agentic_text_2_sql.process_question(task=\"What total number of orders in June 2008?\")\n",
105+
"await Console(result)\n"
154106
]
155107
},
156108
{

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 36 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -68,57 +68,52 @@ def get_all_agents(self):
6868
# Get current datetime for the Query Rewrite Agent
6969
current_datetime = datetime.now()
7070

71-
QUERY_REWRITE_AGENT = LLMAgentCreator.create(
71+
self.query_rewrite_agent = LLMAgentCreator.create(
7272
"query_rewrite_agent", current_datetime=current_datetime
7373
)
7474

75-
SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
75+
self.sql_query_generation_agent = LLMAgentCreator.create(
7676
"sql_query_generation_agent",
7777
target_engine=self.target_engine,
7878
engine_specific_rules=self.engine_specific_rules,
7979
**self.kwargs,
8080
)
8181

82-
SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
82+
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
8383
target_engine=self.target_engine,
8484
engine_specific_rules=self.engine_specific_rules,
8585
**self.kwargs,
8686
)
8787

88-
SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
88+
self.sql_query_correction_agent = LLMAgentCreator.create(
8989
"sql_query_correction_agent",
9090
target_engine=self.target_engine,
9191
engine_specific_rules=self.engine_specific_rules,
9292
**self.kwargs,
9393
)
9494

95-
SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
95+
self.sql_disambiguation_agent = LLMAgentCreator.create(
9696
"sql_disambiguation_agent",
9797
target_engine=self.target_engine,
9898
engine_specific_rules=self.engine_specific_rules,
9999
**self.kwargs,
100100
)
101101

102-
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
103-
"question_decomposition_agent"
104-
)
105-
106102
# Auto-responding UserProxyAgent
107-
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")
103+
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
108104

109105
agents = [
110-
USER_PROXY,
111-
QUERY_REWRITE_AGENT,
112-
SQL_QUERY_GENERATION_AGENT,
113-
SQL_SCHEMA_SELECTION_AGENT,
114-
SQL_QUERY_CORRECTION_AGENT,
115-
QUESTION_DECOMPOSITION_AGENT,
116-
SQL_DISAMBIGUATION_AGENT,
106+
self.user_proxy,
107+
self.query_rewrite_agent,
108+
self.sql_query_generation_agent,
109+
self.sql_schema_selection_agent,
110+
self.sql_query_correction_agent,
111+
self.sql_disambiguation_agent,
117112
]
118113

119114
if self.use_query_cache:
120-
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
121-
agents.append(SQL_QUERY_CACHE_AGENT)
115+
self.query_cache_agent = SqlQueryCacheAgent()
116+
agents.append(self.query_cache_agent)
122117

123118
return agents
124119

@@ -136,7 +131,7 @@ def termination_condition(self):
136131
)
137132
return termination
138133

139-
def unified_selector(messages):
134+
def unified_selector(self, messages):
140135
"""Unified selector for the complete flow."""
141136
logging.info("Messages: %s", messages)
142137
decision = None
@@ -195,18 +190,34 @@ def agentic_flow(self):
195190
)
196191
return flow
197192

198-
async def process_question(self, task: str, chat_history: list[str] = None):
199-
"""Process the complete question through the unified system."""
193+
async def process_question(
194+
self, task: str, chat_history: list[str] = None, parameters: dict = None
195+
):
196+
"""Process the complete question through the unified system.
197+
198+
Args:
199+
----
200+
task (str): The user question to process.
201+
chat_history (list[str], optional): The chat history. Defaults to None.
202+
parameters (dict, optional): The parameters to pass to the agents. Defaults to None.
203+
204+
Returns:
205+
-------
206+
The message stream returned by `agentic_flow.run_stream` (not awaited here; the caller consumes it, e.g. via `Console`).
207+
"""
200208

201209
logging.info("Processing question: %s", task)
202210
logging.info("Chat history: %s", chat_history)
203211

204-
agent_input = {"user_question": task, "chat_history": {}}
212+
agent_input = {
213+
"user_question": task,
214+
"chat_history": {},
215+
"parameters": parameters,
216+
}
205217

206218
if chat_history is not None:
207219
# Update input
208220
for idx, chat in enumerate(chat_history):
209221
agent_input[f"chat_{idx}"] = chat
210222

211-
result = await self.agentic_flow.run_stream(task=json.dumps(agent_input))
212-
return result
223+
return self.agentic_flow.run_stream(task=json.dumps(agent_input))

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,9 @@
1212

1313

1414
class SqlQueryCacheAgent(BaseChatAgent):
15-
def __init__(self, name: str = "sql_query_cache_agent"):
15+
def __init__(self):
1616
super().__init__(
17-
name,
17+
"sql_query_cache_agent",
1818
"An agent that fetches the queries from the cache based on the user question.",
1919
)
2020

@@ -39,10 +39,13 @@ async def on_messages_stream(
3939
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4040
) -> AsyncGenerator[AgentMessage | Response, None]:
4141
# Get the decomposed questions from the query_rewrite_agent
42+
parameter_input = messages[0].content
4243
last_response = messages[-1].content
4344
try:
4445
user_questions = json.loads(last_response)
46+
user_parameters = json.loads(parameter_input)["parameters"]
4547
logging.info(f"Processing questions: {user_questions}")
48+
logging.info(f"Input Parameters: {user_parameters}")
4649

4750
# Initialize results dictionary
4851
cached_results = {
@@ -55,7 +58,7 @@ async def on_messages_stream(
5558
# Fetch the queries from the cache based on the question
5659
logging.info(f"Fetching queries from cache for question: {question}")
5760
cached_query = await self.sql_connector.fetch_queries_from_cache(
58-
question
61+
question, parameters=user_parameters
5962
)
6063

6164
# If any question has pre-run results, set the flag

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 41 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import sqlglot
99
from abc import ABC, abstractmethod
1010
from datetime import datetime
11+
from jinja2 import Template
1112

1213

1314
class SqlConnector(ABC):
@@ -30,6 +31,18 @@ def get_current_datetime(self) -> str:
3031
"""Get the current datetime."""
3132
return datetime.now().strftime("%d/%m/%Y, %H:%M:%S")
3233

34+
def get_current_date(self) -> str:
35+
"""Get the current date."""
36+
return datetime.now().strftime("%d/%m/%Y")
37+
38+
def get_current_time(self) -> str:
39+
"""Get the current time."""
40+
return datetime.now().strftime("%H:%M:%S")
41+
42+
def get_current_unix_timestamp(self) -> int:
43+
"""Get the current unix timestamp."""
44+
return int(datetime.now().timestamp())
45+
3346
@abstractmethod
3447
async def query_execution(
3548
self,
@@ -118,7 +131,9 @@ async def query_validation(
118131
logging.info("SQL Query is valid.")
119132
return True
120133

121-
async def fetch_queries_from_cache(self, question: str) -> str:
134+
async def fetch_queries_from_cache(
135+
self, question: str, parameters: dict = None
136+
) -> str:
122137
"""Fetch the queries from the cache based on the question.
123138
124139
Args:
@@ -129,6 +144,23 @@ async def fetch_queries_from_cache(self, question: str) -> str:
129144
-------
130145
str: The formatted string of the queries fetched from the cache. This is injected into the prompt.
131146
"""
147+
148+
if parameters is None:
149+
parameters = {}
150+
151+
# Populate the parameters
152+
if "date" not in parameters:
153+
parameters["date"] = self.get_current_date()
154+
155+
if "time" not in parameters:
156+
parameters["time"] = self.get_current_time()
157+
158+
if "datetime" not in parameters:
159+
parameters["datetime"] = self.get_current_datetime()
160+
161+
if "unix_timestamp" not in parameters:
162+
parameters["unix_timestamp"] = self.get_current_unix_timestamp()
163+
132164
cached_schemas = await self.ai_search_connector.run_ai_search_query(
133165
question,
134166
["QuestionEmbedding"],
@@ -146,6 +178,14 @@ async def fetch_queries_from_cache(self, question: str) -> str:
146178
"cached_questions_and_schemas": None,
147179
}
148180

181+
# loop through all sql queries and populate the template in place
182+
for schema in cached_schemas:
183+
sql_queries = schema["SqlQueryDecomposition"]
184+
for sql_query in sql_queries:
185+
sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render(
186+
**parameters
187+
)
188+
149189
logging.info("Cached schemas: %s", cached_schemas)
150190
if self.pre_run_query_cache and len(cached_schemas) > 0:
151191
# check the score

0 commit comments

Comments
 (0)