Skip to content

Commit eb398f8

Browse files
committed
Update prompts and agents
1 parent f1fd21c commit eb398f8

File tree

12 files changed

+202
-116
lines changed

12 files changed

+202
-116
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
from typing import AsyncGenerator, List, Sequence
44

55
from autogen_agentchat.agents import BaseChatAgent
6-
from autogen_agentchat.base import Response, TaskResult
7-
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
6+
from autogen_agentchat.base import Response
7+
from autogen_agentchat.messages import (
8+
AgentMessage,
9+
ChatMessage,
10+
TextMessage,
11+
ToolCallResultMessage,
12+
)
813
from autogen_core import CancellationToken
914
import json
1015
import logging
1116
from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql
1217
from aiostream import stream
1318
from json import JSONDecodeError
19+
import re
1420

1521

1622
class ParallelQuerySolvingAgent(BaseChatAgent):
@@ -53,9 +59,6 @@ def parse_inner_message(self, message):
5359
except JSONDecodeError:
5460
pass
5561

56-
# Try to extract JSON from markdown code blocks
57-
import re
58-
5962
json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL)
6063
if json_match:
6164
try:
@@ -103,12 +106,13 @@ async def consume_inner_messages_from_agentic_flow(
103106

104107
logging.info(f"Checking Inner Message: {inner_message}")
105108

106-
if isinstance(inner_message, TaskResult) is False:
107-
try:
109+
try:
110+
if isinstance(inner_message, ToolCallResultMessage):
111+
# Check for SQL query results
108112
parsed_message = self.parse_inner_message(inner_message.content)
113+
109114
logging.info(f"Inner Loaded: {parsed_message}")
110115

111-
# Search for specific message types and add them to the final output object
112116
if isinstance(parsed_message, dict):
113117
if (
114118
"type" in parsed_message
@@ -124,6 +128,13 @@ async def consume_inner_messages_from_agentic_flow(
124128
}
125129
)
126130

131+
elif isinstance(inner_message, TextMessage):
132+
parsed_message = self.parse_inner_message(inner_message.content)
133+
134+
logging.info(f"Inner Loaded: {parsed_message}")
135+
136+
# Search for specific message types and add them to the final output object
137+
if isinstance(parsed_message, dict):
127138
if ("contains_pre_run_results" in parsed_message) and (
128139
parsed_message["contains_pre_run_results"] is True
129140
):
@@ -139,8 +150,8 @@ async def consume_inner_messages_from_agentic_flow(
139150
}
140151
)
141152

142-
except Exception as e:
143-
logging.warning(f"Error processing message: {e}")
153+
except Exception as e:
154+
logging.warning(f"Error processing message: {e}")
144155

145156
yield inner_message
146157

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from autogen_agentchat.base import Response
77
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
88
from autogen_core import CancellationToken
9-
from text_2_sql_core.connectors.factory import ConnectorFactory
9+
from text_2_sql_core.custom_agents.sql_query_cache_agent import (
10+
SqlQueryCacheAgentCustomAgent,
11+
)
1012
import json
1113
import logging
1214

@@ -18,7 +20,7 @@ def __init__(self):
1820
"An agent that fetches the queries from the cache based on the user question.",
1921
)
2022

21-
self.sql_connector = ConnectorFactory.get_database_connector()
23+
self.agent = SqlQueryCacheAgentCustomAgent()
2224

2325
@property
2426
def produced_message_types(self) -> List[type[ChatMessage]]:
@@ -49,31 +51,9 @@ async def on_messages_stream(
4951
# If not JSON array, process as single question
5052
raise ValueError("Could not load message")
5153

52-
# Initialize results dictionary
53-
cached_results = {
54-
"cached_questions_and_schemas": [],
55-
"contains_pre_run_results": False,
56-
}
57-
58-
# Process each question sequentially
59-
for question in user_questions:
60-
# Fetch the queries from the cache based on the question
61-
logging.info(f"Fetching queries from cache for question: {question}")
62-
cached_query = await self.sql_connector.fetch_queries_from_cache(
63-
question, injected_parameters=injected_parameters
64-
)
65-
66-
# If any question has pre-run results, set the flag
67-
if cached_query.get("contains_pre_run_results", False):
68-
cached_results["contains_pre_run_results"] = True
69-
70-
# Add the cached results for this question
71-
if cached_query.get("cached_questions_and_schemas"):
72-
cached_results["cached_questions_and_schemas"].extend(
73-
cached_query["cached_questions_and_schemas"]
74-
)
75-
76-
logging.info(f"Final cached results: {cached_results}")
54+
cached_results = await self.agent.process_message(
55+
user_questions, injected_parameters
56+
)
7757
yield Response(
7858
chat_message=TextMessage(
7959
content=json.dumps(cached_results), source=self.name

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 13 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
from autogen_agentchat.base import Response
77
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
88
from autogen_core import CancellationToken
9-
from text_2_sql_core.connectors.factory import ConnectorFactory
109
import json
1110
import logging
12-
from text_2_sql_core.prompts.load import load
13-
from jinja2 import Template
14-
import asyncio
11+
from text_2_sql_core.custom_agents.sql_schema_selection_agent import (
12+
SqlSchemaSelectionAgentCustomAgent,
13+
)
1514

1615

1716
class SqlSchemaSelectionAgent(BaseChatAgent):
@@ -21,15 +20,7 @@ def __init__(self, **kwargs):
2120
"An agent that fetches the schemas from the cache based on the user question.",
2221
)
2322

24-
self.ai_search_connector = ConnectorFactory.get_ai_search_connector()
25-
26-
self.open_ai_connector = ConnectorFactory.get_open_ai_connector()
27-
28-
self.sql_connector = ConnectorFactory.get_database_connector()
29-
30-
system_prompt = load("sql_schema_selection_agent")["system_message"]
31-
32-
self.system_prompt = Template(system_prompt).render(kwargs)
23+
self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs)
3324

3425
@property
3526
def produced_message_types(self) -> List[type[ChatMessage]]:
@@ -49,64 +40,15 @@ async def on_messages(
4940
async def on_messages_stream(
5041
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
5142
) -> AsyncGenerator[AgentMessage | Response, None]:
52-
last_response = messages[-1].content
53-
54-
# load the json of the last message and get the user question's
55-
56-
user_questions = json.loads(last_response)
57-
58-
logging.info(f"User questions: {user_questions}")
59-
60-
entity_tasks = []
61-
62-
for user_question in user_questions:
63-
messages = [
64-
{"role": "system", "content": self.system_prompt},
65-
{"role": "user", "content": user_question},
66-
]
67-
entity_tasks.append(self.open_ai_connector.run_completion_request(messages))
68-
69-
entity_results = await asyncio.gather(*entity_tasks)
70-
71-
entity_search_tasks = []
72-
column_search_tasks = []
73-
74-
for entity_result in entity_results:
75-
loaded_entity_result = json.loads(entity_result)
76-
77-
logging.info(f"Loaded entity result: {loaded_entity_result}")
78-
79-
for entity_group in loaded_entity_result["entities"]:
80-
entity_search_tasks.append(
81-
self.sql_connector.get_entity_schemas(
82-
" ".join(entity_group), as_json=False
83-
)
84-
)
85-
86-
for filter_condition in loaded_entity_result["filter_conditions"]:
87-
column_search_tasks.append(
88-
self.ai_search_connector.get_column_values(
89-
filter_condition, as_json=False
90-
)
91-
)
92-
93-
schemas_results = await asyncio.gather(*entity_search_tasks)
94-
column_value_results = await asyncio.gather(*column_search_tasks)
95-
96-
# deduplicate schemas
97-
final_schemas = []
98-
99-
for schema_result in schemas_results:
100-
for schema in schema_result:
101-
if schema not in final_schemas:
102-
final_schemas.append(schema)
103-
104-
final_results = {
105-
"COLUMN_OPTIONS_AND_VALUES_FOR_FILTERS": column_value_results,
106-
"SCHEMA_OPTIONS": final_schemas,
107-
}
108-
109-
logging.info(f"Final results: {final_results}")
43+
try:
44+
request_details = json.loads(messages[0].content)
45+
user_questions = request_details["question"]
46+
logging.info(f"Processing questions: {user_questions}")
47+
except json.JSONDecodeError:
48+
# If not JSON array, process as single question
49+
raise ValueError("Could not load message")
50+
51+
final_results = await self.agent.process_message(user_questions)
11052

11153
yield Response(
11254
chat_message=TextMessage(

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/open_ai.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def get_authentication_properties(cls) -> dict:
2828
return token_provider, api_key
2929

3030
async def run_completion_request(
31-
self, messages: list[dict], temperature=0, max_tokens=2000, model="4o-mini"
31+
self,
32+
messages: list[dict],
33+
temperature=0,
34+
max_tokens=2000,
35+
model="4o-mini",
36+
response_format=None,
3237
) -> str:
3338
if model == "4o-mini":
3439
model_deployment = os.environ["OpenAI__MiniCompletionDeployment"]
@@ -45,13 +50,29 @@ async def run_completion_request(
4550
azure_ad_token_provider=token_provider,
4651
api_key=api_key,
4752
) as open_ai_client:
48-
response = await open_ai_client.chat.completions.create(
49-
model=model_deployment,
50-
messages=messages,
51-
temperature=temperature,
52-
max_tokens=max_tokens,
53-
)
54-
return response.choices[0].message.content
53+
if response_format is not None:
54+
response = await open_ai_client.beta.chat.completions.parse(
55+
model=model_deployment,
56+
messages=messages,
57+
temperature=temperature,
58+
max_tokens=max_tokens,
59+
response_format=response_format,
60+
)
61+
else:
62+
response = await open_ai_client.chat.completions.create(
63+
model=model_deployment,
64+
messages=messages,
65+
temperature=temperature,
66+
max_tokens=max_tokens,
67+
)
68+
69+
message = response.choices[0].message
70+
if response_format is not None and message.parsed is not None:
71+
return message.parsed
72+
elif response_format is not None:
73+
return message.refusal
74+
else:
75+
return message.content
5576

5677
async def run_embedding_request(self, batch: list[str]):
5778
token_provider, api_key = self.get_authentication_properties()

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/tsql_sql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ async def get_entity_schemas(
127127

128128
del schema["Entity"]
129129
del schema["Schema"]
130+
del schema["Database"]
130131

131132
if as_json:
132133
return json.dumps(schemas, default=str)

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from text_2_sql_core.connectors.factory import ConnectorFactory
4+
import logging
5+
6+
7+
class SqlQueryCacheAgentCustomAgent:
    """Framework-agnostic agent that answers user questions from the SQL query cache.

    Extracted from the AutoGen-specific ``SqlQueryCacheAgent`` so the cache
    lookup logic can be reused by any orchestration framework.
    """

    def __init__(self):
        # Concrete database connector (TSQL, etc.) is chosen by the factory
        # from configuration; only the cache-fetch API is used here.
        self.sql_connector = ConnectorFactory.get_database_connector()

    async def process_message(
        self, user_questions: list[str], injected_parameters: dict
    ) -> dict:
        """Fetch cached queries/schemas for each question and aggregate them.

        Args:
            user_questions: Questions to look up in the cache, processed
                sequentially (each lookup is awaited before the next).
            injected_parameters: Parameters forwarded verbatim to
                ``fetch_queries_from_cache`` (e.g. values substituted into
                pre-run queries — TODO confirm exact semantics with connector).

        Returns:
            A dict with:
              - ``"cached_questions_and_schemas"``: list concatenating every
                cache hit across all questions (empty list when no hits);
              - ``"contains_pre_run_results"``: True if any individual lookup
                reported pre-run results, else False.
        """
        # Aggregate across all questions; start from the empty/False state so
        # an empty question list yields a well-formed result.
        cached_results = {
            "cached_questions_and_schemas": [],
            "contains_pre_run_results": False,
        }

        for question in user_questions:
            logging.info(f"Fetching queries from cache for question: {question}")
            cached_query = await self.sql_connector.fetch_queries_from_cache(
                question, injected_parameters=injected_parameters
            )

            # Sticky flag: once any question has pre-run results, the
            # aggregate reports True.
            if cached_query.get("contains_pre_run_results", False):
                cached_results["contains_pre_run_results"] = True

            # Only extend when the lookup returned a non-empty hit list.
            if cached_query.get("cached_questions_and_schemas"):
                cached_results["cached_questions_and_schemas"].extend(
                    cached_query["cached_questions_and_schemas"]
                )

        logging.info(f"Final cached results: {cached_results}")
        return cached_results

0 commit comments

Comments
 (0)