Skip to content

Commit e8ddbec

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/feature/postgres-support' into feature/postgres-support
2 parents 3ac3ce1 + 441c977 commit e8ddbec

File tree

2 files changed

+104
-56
lines changed

2 files changed

+104
-56
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import os
1818
from datetime import datetime
19+
import re
1920

2021
from text_2_sql_core.payloads.interaction_payloads import (
2122
QuestionPayload,
def extract_disambiguation_request(
    self, messages: list
) -> DismabiguationRequestPayload:
    """Wrap the final message's content in a disambiguation request payload."""
    return DismabiguationRequestPayload(
        disambiguation_request=messages[-1].content,
    )
118116

117+
def parse_message_content(self, content):
    """Normalize a message payload for downstream JSON handling.

    Lists and dicts are serialized to a JSON string. Strings are decoded
    from a ```json fenced code block when present, then as raw JSON;
    content that cannot be decoded is returned untouched.
    """
    if isinstance(content, (list, dict)):
        return json.dumps(content)

    # Prefer JSON embedded in a markdown fence over the raw string.
    fenced = re.search(r"```json\s*(.*?)\s*```", content, re.DOTALL)
    if fenced is not None:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass

    try:
        return json.loads(content)
    except json.JSONDecodeError:
        pass

    # Nothing parsed — hand the caller the original content.
    return content
139+
119140
def extract_sources(self, messages: list) -> AnswerWithSourcesPayload:
    """Build the final answer payload, attaching SQL sources when available.

    The last message holds the natural-language answer; the second-to-last
    holds the SQL query results. Parsing is best-effort by design: any
    failure still returns a payload carrying the answer, just without
    sources.
    """
    answer = messages[-1].content
    sql_query_results = self.parse_message_content(messages[-2].content)

    try:
        # parse_message_content returns a JSON string for dict/list input
        # and may return an unparsed string — normalize to a Python object
        # before inspecting it.
        if isinstance(sql_query_results, str):
            sql_query_results = json.loads(sql_query_results)

        logging.info("SQL Query Results: %s", sql_query_results)
        payload = AnswerWithSourcesPayload(answer=answer)

        if isinstance(sql_query_results, dict) and "results" in sql_query_results:
            for question, sql_query_result_list in sql_query_results[
                "results"
            ].items():
                logging.info(
                    "SQL Query Result for question '%s': %s",
                    question,
                    sql_query_result_list,
                )
                for sql_query_result in sql_query_result_list:
                    logging.info("SQL Query Result: %s", sql_query_result)
                    payload.body.sources.append(
                        AnswerWithSourcesPayload.Body.Source(
                            sql_query=sql_query_result["sql_query"],
                            sql_rows=sql_query_result["sql_rows"],
                        )
                    )

        return payload

    except Exception:
        # Deliberate catch-all: a malformed result set must never lose the
        # answer. logging.exception keeps the traceback that the previous
        # "%s" % str(e) formatting discarded.
        logging.exception("Error processing results")
        return AnswerWithSourcesPayload(answer=answer)
154175

155176
async def process_question(
156177
self,

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,37 @@ async def on_messages(
3838
assert response is not None
3939
return response
4040

41+
def parse_inner_message(self, message):
    """Best-effort parse of an inner agent message into structured data.

    dict/list inputs are returned unchanged; anything else is coerced to
    str and tried first as raw JSON, then as a ```json fenced block.
    Unparseable content is returned as-is — parsing never raises.
    """
    # Local import kept from the original; this module does not import re
    # at the top level. Hoisted to the function head so it is not buried
    # between parse attempts.
    import re

    try:
        if isinstance(message, (dict, list)):
            return message

        if not isinstance(message, str):
            message = str(message)

        # Try to parse as JSON first. json.JSONDecodeError is used instead
        # of a bare JSONDecodeError so the except clause does not depend on
        # a separate from-import (it is the same class).
        try:
            return json.loads(message)
        except json.JSONDecodeError:
            pass

        # Try to extract JSON from a markdown ```json code block.
        json_match = re.search(r"```json\s*(.*?)\s*```", message, re.DOTALL)
        if json_match:
            try:
                return json.loads(json_match.group(1))
            except json.JSONDecodeError:
                pass

        # If we can't parse it, return it as-is.
        return message
    except Exception as e:
        # Defensive: a parsing surprise must not break the message stream.
        # Lazy %-args instead of an f-string per logging convention.
        logging.warning("Error parsing message: %s", e)
        return message
71+
4172
async def on_messages_stream(
4273
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4374
) -> AsyncGenerator[AgentMessage | Response, None]:
@@ -74,46 +105,42 @@ async def consume_inner_messages_from_agentic_flow(
74105

75106
if isinstance(inner_message, TaskResult) is False:
76107
try:
77-
inner_message = json.loads(inner_message.content)
78-
logging.info(f"Inner Loaded: {inner_message}")
108+
parsed_message = self.parse_inner_message(inner_message.content)
109+
logging.info(f"Inner Loaded: {parsed_message}")
79110

80111
# Search for specific message types and add them to the final output object
81-
if (
82-
"type" in inner_message
83-
and inner_message["type"] == "query_execution_with_limit"
84-
):
85-
database_results[identifier].append(
86-
{
87-
"sql_query": inner_message["sql_query"].replace(
88-
"\n", " "
89-
),
90-
"sql_rows": inner_message["sql_rows"],
91-
}
92-
)
93-
94-
if ("contains_pre_run_results" in inner_message) and (
95-
inner_message["contains_pre_run_results"] is True
96-
):
97-
for pre_run_sql_query, pre_run_result in inner_message[
98-
"cached_questions_and_schemas"
99-
].items():
112+
if isinstance(parsed_message, dict):
113+
if (
114+
"type" in parsed_message
115+
and parsed_message["type"]
116+
== "query_execution_with_limit"
117+
):
100118
database_results[identifier].append(
101119
{
102-
"sql_query": pre_run_sql_query.replace(
103-
"\n", " "
104-
),
105-
"sql_rows": pre_run_result["sql_rows"],
120+
"sql_query": parsed_message[
121+
"sql_query"
122+
].replace("\n", " "),
123+
"sql_rows": parsed_message["sql_rows"],
106124
}
107125
)
108126

109-
except (JSONDecodeError, TypeError) as e:
110-
logging.error("Could not load message: %s", inner_message)
111-
logging.warning(f"Error processing message: {e}")
127+
if ("contains_pre_run_results" in parsed_message) and (
128+
parsed_message["contains_pre_run_results"] is True
129+
):
130+
for pre_run_sql_query, pre_run_result in parsed_message[
131+
"cached_questions_and_schemas"
132+
].items():
133+
database_results[identifier].append(
134+
{
135+
"sql_query": pre_run_sql_query.replace(
136+
"\n", " "
137+
),
138+
"sql_rows": pre_run_result["sql_rows"],
139+
}
140+
)
112141

113142
except Exception as e:
114-
logging.error("Could not load message: %s", inner_message)
115-
logging.error(f"Error processing message: {e}")
116-
raise e
143+
logging.warning(f"Error processing message: {e}")
117144

118145
yield inner_message
119146

0 commit comments

Comments
 (0)