Skip to content

Commit 853569e

Browse files
committed
Update agents
1 parent 307204d commit 853569e

File tree

6 files changed

+60
-75
lines changed

6 files changed

+60
-75
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,26 +219,26 @@ async def consume_inner_messages_from_agentic_flow(
219219
# Create an instance of the InnerAutoGenText2Sql class
220220
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
221221

222-
identifier = ", ".join(parallel_message)
223-
224222
# Add database connection info to injected parameters
225223
query_params = injected_parameters.copy() if injected_parameters else {}
226224
if "Text2Sql__Tsql__ConnectionString" in os.environ:
227225
query_params["database_connection_string"] = os.environ[
228226
"Text2Sql__Tsql__ConnectionString"
229227
]
230228
if "Text2Sql__Tsql__Database" in os.environ:
231-
query_params["database_name"] = os.environ["Text2Sql__Tsql__Database"]
229+
query_params["database_name"] = os.environ[
230+
"Text2Sql__Tsql__Database"
231+
]
232232

233233
# Launch tasks for each sub-query
234234
inner_solving_generators.append(
235235
consume_inner_messages_from_agentic_flow(
236236
inner_autogen_text_2_sql.process_user_message(
237237
user_message=parallel_message,
238238
injected_parameters=query_params,
239-
database_results=filtered_parallel_messages.database_results
239+
database_results=filtered_parallel_messages.database_results,
240240
),
241-
identifier,
241+
parallel_message,
242242
filtered_parallel_messages,
243243
)
244244
)
@@ -267,7 +267,11 @@ async def consume_inner_messages_from_agentic_flow(
267267
# Check for disambiguation requests before processing the next round
268268

269269
if (
270-
max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
270+
max(
271+
map(
272+
len, filtered_parallel_messages.disambiguation_requests.values()
273+
)
274+
)
271275
> 0
272276
):
273277
# Final response

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ async def on_messages_stream(
4444
try:
4545
request_details = json.loads(messages[0].content)
4646
injected_parameters = request_details["injected_parameters"]
47-
user_messages = request_details["user_message"]
48-
logging.info(f"Processing messages: {user_messages}")
47+
user_message = request_details["user_message"]
48+
logging.info(f"Processing messages: {user_message}")
4949
logging.info(f"Input Parameters: {injected_parameters}")
5050
except json.JSONDecodeError:
5151
# If not JSON array, process as single message
5252
raise ValueError("Could not load message")
5353

5454
cached_results = await self.agent.process_message(
55-
user_messages, injected_parameters
55+
user_message, injected_parameters
5656
)
5757
yield Response(
5858
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,14 @@ async def on_messages_stream(
4343
# Try to parse as JSON first
4444
try:
4545
request_details = json.loads(messages[0].content)
46-
messages = request_details["question"]
46+
message = request_details["user_message"]
4747
except (json.JSONDecodeError, KeyError):
4848
# If not JSON or missing question key, use content directly
49-
messages = messages[0].content
49+
message = messages[0].content
5050

51-
if isinstance(messages, str):
52-
messages = [messages]
53-
elif not isinstance(messages, list):
54-
messages = [str(messages)]
51+
logging.info(f"Processing message: {message}")
5552

56-
logging.info(f"Processing questions: {messages}")
57-
58-
final_results = await self.agent.process_message(messages)
53+
final_results = await self.agent.process_message(message)
5954

6055
yield Response(
6156
chat_message=TextMessage(

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,35 @@ class SqlQueryCacheAgentCustomAgent:
88
def __init__(self):
99
self.sql_connector = ConnectorFactory.get_database_connector()
1010

11-
async def process_message(
12-
self, messages: list[str], injected_parameters: dict
13-
) -> dict:
11+
async def process_message(self, message: str, injected_parameters: dict) -> dict:
1412
# Initialize results dictionary
1513
cached_results = {
1614
"cached_sql_queries_with_schemas_from_cache": [],
1715
"contains_cached_sql_queries_with_schemas_from_cache_database_results": False,
1816
}
1917

20-
# Process each question sequentially
21-
for message in messages:
22-
# Fetch the queries from the cache based on the question
23-
logging.info(f"Fetching queries from cache for question: {message}")
24-
cached_query = (
25-
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
26-
message, injected_parameters=injected_parameters
27-
)
18+
# Fetch the queries from the cache based on the question
19+
logging.info(f"Fetching queries from cache for question: {message}")
20+
cached_query = (
21+
await self.sql_connector.fetch_sql_queries_with_schemas_from_cache(
22+
message, injected_parameters=injected_parameters
2823
)
24+
)
2925

30-
# If any question has pre-run results, set the flag
31-
if cached_query.get(
32-
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
33-
False,
34-
):
35-
cached_results[
36-
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
37-
] = True
26+
# If any question has pre-run results, set the flag
27+
if cached_query.get(
28+
"contains_cached_sql_queries_with_schemas_from_cache_database_results",
29+
False,
30+
):
31+
cached_results[
32+
"contains_cached_sql_queries_with_schemas_from_cache_database_results"
33+
] = True
3834

39-
# Add the cached results for this question
40-
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
41-
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
42-
cached_query["cached_sql_queries_with_schemas_from_cache"]
43-
)
35+
# Add the cached results for this question
36+
if cached_query.get("cached_sql_queries_with_schemas_from_cache"):
37+
cached_results["cached_sql_queries_with_schemas_from_cache"].extend(
38+
cached_query["cached_sql_queries_with_schemas_from_cache"]
39+
)
4440

4541
logging.info(f"Final cached results: {cached_results}")
4642
return cached_results

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,35 @@ def __init__(self, **kwargs):
2222

2323
self.system_prompt = Template(system_prompt).render(kwargs)
2424

25-
async def process_message(self, messages: list[str]) -> dict:
26-
logging.info(f"user inputs: {messages}")
27-
28-
entity_tasks = []
29-
30-
for message in messages:
31-
messages = [
32-
{"role": "system", "content": self.system_prompt},
33-
{"role": "user", "content": message},
34-
]
35-
entity_tasks.append(
36-
self.open_ai_connector.run_completion_request(
37-
messages, response_format=SQLSchemaSelectionAgentOutput
38-
)
39-
)
25+
async def process_message(self, message: str) -> dict:
26+
logging.info(f"Processing message: {message}")
4027

41-
entity_results = await asyncio.gather(*entity_tasks)
28+
messages = [
29+
{"role": "system", "content": self.system_prompt},
30+
{"role": "user", "content": message},
31+
]
32+
entity_result = await self.open_ai_connector.run_completion_request(
33+
messages, response_format=SQLSchemaSelectionAgentOutput
34+
)
4235

4336
entity_search_tasks = []
4437
column_search_tasks = []
4538

46-
for entity_result in entity_results:
47-
logging.info(f"Entity result: {entity_result}")
39+
logging.info(f"Entity result: {entity_result}")
4840

49-
for entity_group in entity_result.entities:
50-
logging.info("Searching for schemas for entity group: %s", entity_group)
51-
entity_search_tasks.append(
52-
self.sql_connector.get_entity_schemas(
53-
" ".join(entity_group), as_json=False
54-
)
41+
for entity_group in entity_result.entities:
42+
logging.info("Searching for schemas for entity group: %s", entity_group)
43+
entity_search_tasks.append(
44+
self.sql_connector.get_entity_schemas(
45+
" ".join(entity_group), as_json=False
5546
)
47+
)
5648

57-
for filter_condition in entity_result.filter_conditions:
58-
logging.info(
59-
"Searching for column values for filter: %s", filter_condition
60-
)
61-
column_search_tasks.append(
62-
self.sql_connector.get_column_values(
63-
filter_condition, as_json=False
64-
)
65-
)
49+
for filter_condition in entity_result.filter_conditions:
50+
logging.info("Searching for column values for filter: %s", filter_condition)
51+
column_search_tasks.append(
52+
self.sql_connector.get_column_values(filter_condition, as_json=False)
53+
)
6654

6755
schemas_results = await asyncio.gather(*entity_search_tasks)
6856
column_value_results = await asyncio.gather(*column_search_tasks)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,5 @@ system_message:
274274
TERMINATE
275275
</output_format>
276276
"
277+
tools:
278+
- sql_get_entity_schemas_tool

0 commit comments

Comments
 (0)