Skip to content

Commit bd4b4b0

Browse files
committed
Fix disambiguation
1 parent b200d21 commit bd4b4b0

File tree

7 files changed

+124
-36
lines changed

7 files changed

+124
-36
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def termination_condition(self):
6262
termination = (
6363
TextMentionTermination("TERMINATE")
6464
| SourceMatchTermination("answer_agent")
65-
| TextMentionTermination("requires_user_information_request")
65+
| TextMentionTermination("contains_disambiguation_requests")
6666
| MaxMessageTermination(5)
6767
)
6868
return termination
@@ -131,20 +131,43 @@ def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
131131
sub_message_results = self.parse_message_content(messages[1].content)
132132
logging.info("Decomposed Results: %s", sub_message_results)
133133

134-
return sub_message_results.get("decomposed_user_messages", [])
134+
decomposed_user_messages = sub_message_results.get(
135+
"decomposed_user_messages", []
136+
)
137+
138+
logging.debug(
139+
"Returning decomposed_user_messages: %s", decomposed_user_messages
140+
)
141+
142+
return decomposed_user_messages
135143

136144
def extract_disambiguation_request(
137145
self, messages: list
138146
) -> DismabiguationRequestsPayload:
139147
"""Extract the disambiguation request from the answer."""
140-
disambiguation_request = messages[-1].content
148+
all_disambiguation_requests = self.parse_message_content(messages[-1].content)
141149

142150
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
143-
return DismabiguationRequestsPayload(
144-
disambiguation_request=disambiguation_request,
145-
decomposed_user_messages=decomposed_user_messages,
151+
request_payload = DismabiguationRequestsPayload(
152+
decomposed_user_messages=decomposed_user_messages
146153
)
147154

155+
for per_question_disambiguation_request in all_disambiguation_requests[
156+
"disambiguation_requests"
157+
].values():
158+
for disambiguation_request in per_question_disambiguation_request:
159+
logging.info(
160+
"Disambiguation Request Identified: %s", disambiguation_request
161+
)
162+
163+
request = DismabiguationRequestsPayload.Body.DismabiguationRequest(
164+
agent_question=disambiguation_request["agent_question"],
165+
user_choices=disambiguation_request["user_choices"],
166+
)
167+
request_payload.body.disambiguation_requests.append(request)
168+
169+
return request_payload
170+
148171
def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
149172
"""Extract the sources from the answer."""
150173
answer = messages[-1].content
@@ -169,11 +192,13 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
169192
logging.error(f"Expected dict, got {type(sql_query_results)}")
170193
return payload
171194

172-
if "results" not in sql_query_results:
195+
if "database_results" not in sql_query_results:
173196
logging.error("No 'results' key in sql_query_results")
174197
return payload
175198

176-
for message, sql_query_result_list in sql_query_results["results"].items():
199+
for message, sql_query_result_list in sql_query_results[
200+
"database_results"
201+
].items():
177202
if not sql_query_result_list: # Check if list is empty
178203
logging.warning(f"No results for message: {message}")
179204
continue

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@
1818
from json import JSONDecodeError
1919
import re
2020
import os
21+
from pydantic import BaseModel, Field
22+
23+
24+
class FilteredParallelMessagesCollection(BaseModel):
    """Per-identifier buckets of messages gathered from the parallel inner flows.

    Both mappings are keyed by the identifier passed to ``add_identifier``;
    each value is a list that callers append to directly.
    """

    # Query results collected for each identifier.
    database_results: dict[str, list] = Field(default_factory=dict)
    # Disambiguation requests collected for each identifier.
    disambiguation_requests: dict[str, list] = Field(default_factory=dict)

    def add_identifier(self, identifier):
        """Ensure both mappings hold a list for ``identifier``.

        Existing entries are left untouched, so calling this repeatedly for
        the same identifier never discards previously appended items.
        """
        for bucket in (self.database_results, self.disambiguation_requests):
            # setdefault only inserts the empty list when the key is absent.
            bucket.setdefault(identifier, [])
2133

2234

2335
class ParallelQuerySolvingAgent(BaseChatAgent):
@@ -89,7 +101,7 @@ async def on_messages_stream(
89101
logging.info(f"Query Rewrites: {message_rewrites}")
90102

91103
async def consume_inner_messages_from_agentic_flow(
92-
agentic_flow, identifier, database_results
104+
agentic_flow, identifier, filtered_parallel_messages
93105
):
94106
"""
95107
Consume the inner messages and append them to the specified list.
@@ -101,8 +113,7 @@ async def consume_inner_messages_from_agentic_flow(
101113
"""
102114
async for inner_message in agentic_flow:
103115
# Add message to results dictionary, tagged by the function name
104-
if identifier not in database_results:
105-
database_results[identifier] = []
116+
filtered_parallel_messages.add_identifier(identifier)
106117

107118
logging.info(f"Checking Inner Message: {inner_message}")
108119

@@ -122,7 +133,9 @@ async def consume_inner_messages_from_agentic_flow(
122133
== "query_execution_with_limit"
123134
):
124135
logging.info("Contains query results")
125-
database_results[identifier].append(
136+
filtered_parallel_messages.database_results[
137+
identifier
138+
].append(
126139
{
127140
"sql_query": parsed_message[
128141
"sql_query"
@@ -138,29 +151,43 @@ async def consume_inner_messages_from_agentic_flow(
138151

139152
# Search for specific message types and add them to the final output object
140153
if isinstance(parsed_message, dict):
154+
# Check if the message contains pre-run results
141155
if ("contains_pre_run_results" in parsed_message) and (
142156
parsed_message["contains_pre_run_results"] is True
143157
):
144158
logging.info("Contains pre-run results")
145159
for pre_run_sql_query, pre_run_result in parsed_message[
146160
"cached_messages_and_schemas"
147161
].items():
148-
database_results[identifier].append(
162+
filtered_parallel_messages.database_results[
163+
identifier
164+
].append(
149165
{
150166
"sql_query": pre_run_sql_query.replace(
151167
"\n", " "
152168
),
153169
"sql_rows": pre_run_result["sql_rows"],
154170
}
155171
)
172+
# Check if disambiguation is required
173+
elif ("disambiguation_requests" in parsed_message) and (
174+
parsed_message["disambiguation_requests"]
175+
):
176+
logging.info("Contains disambiguation requests")
177+
for disambiguation_request in parsed_message[
178+
"disambiguation_requests"
179+
]:
180+
filtered_parallel_messages.disambiguation_requests[
181+
identifier
182+
].append(disambiguation_request)
156183

157184
except Exception as e:
158185
logging.warning(f"Error processing message: {e}")
159186

160187
yield inner_message
161188

162189
inner_solving_generators = []
163-
database_results = {}
190+
filtered_parallel_messages = FilteredParallelMessagesCollection()
164191

165192
# Convert all_non_database_query to lowercase string and compare
166193
all_non_database_query = str(
@@ -201,7 +228,7 @@ async def consume_inner_messages_from_agentic_flow(
201228
injected_parameters=query_params,
202229
),
203230
identifier,
204-
database_results,
231+
filtered_parallel_messages,
205232
)
206233
)
207234

@@ -218,17 +245,43 @@ async def consume_inner_messages_from_agentic_flow(
218245
yield inner_message
219246

220247
# Log final results for debugging or auditing
221-
logging.info(f"Database Results: {database_results}")
248+
logging.info(
249+
"Database Results: %s", filtered_parallel_messages.database_results
250+
)
251+
logging.info(
252+
"Disambiguation Requests: %s",
253+
filtered_parallel_messages.disambiguation_requests,
254+
)
222255

223-
# Final response
224-
yield Response(
225-
chat_message=TextMessage(
226-
content=json.dumps(
227-
{"contains_results": True, "results": database_results}
256+
if (
257+
max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
258+
> 0
259+
):
260+
# Final response
261+
yield Response(
262+
chat_message=TextMessage(
263+
content=json.dumps(
264+
{
265+
"contains_disambiguation_requests": True,
266+
"disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
267+
}
268+
),
269+
source=self.name,
228270
),
229-
source=self.name,
230-
),
231-
)
271+
)
272+
else:
273+
# Final response
274+
yield Response(
275+
chat_message=TextMessage(
276+
content=json.dumps(
277+
{
278+
"contains_database_results": True,
279+
"database_results": filtered_parallel_messages.database_results,
280+
}
281+
),
282+
source=self.name,
283+
),
284+
)
232285

233286
async def on_reset(self, cancellation_token: CancellationToken) -> None:
234287
pass

text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def extract_sql_queries_from_results(results: Dict[str, Any]) -> List[str]:
1616
"""
1717
queries = []
1818

19-
if results.get("contains_results") and results.get("results"):
19+
if results.get("contains_database_results") and results.get("results"):
2020
for question_results in results["results"].values():
2121
for result in question_results:
2222
if isinstance(result, dict) and "sql_query" in result:

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ def get_all_agents(self):
124124
@property
125125
def termination_condition(self):
126126
"""Define the termination condition for the chat."""
127-
termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(10)
127+
termination = (
128+
TextMentionTermination("TERMINATE")
129+
| MaxMessageTermination(10)
130+
| TextMentionTermination("disambiguation_request")
131+
)
128132
return termination
129133

130134
def unified_selector(self, messages):

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,17 @@ async def query_execution_with_limit(
187187
"type": "query_execution_with_limit",
188188
"sql_query": sql_query,
189189
"sql_rows": result,
190-
}
190+
},
191+
default=str,
191192
)
192193
else:
193194
return json.dumps(
194195
{
195196
"type": "errored_query_execution_with_limit",
196197
"sql_query": sql_query,
197198
"errors": validation_result,
198-
}
199+
},
200+
default=str,
199201
)
200202

201203
async def query_validation(

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class DismabiguationRequest(InteractionPayloadBase):
4242
agent_question: str | None = Field(..., alias="agentQuestion")
4343
user_choices: list[str] | None = Field(default=None, alias="userChoices")
4444

45-
disambiguation_requests: list[DismabiguationRequest] = Field(
46-
alias="disambiguationRequests"
45+
disambiguation_requests: list[DismabiguationRequest] | None = Field(
46+
default_factory=list, alias="disambiguationRequests"
4747
)
4848
decomposed_user_messages: list[list[str]] = Field(
4949
default_factory=list, alias="decomposedUserMessages"

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/disambiguation_and_sql_query_generation_agent.yaml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,16 @@ system_message:
170170
171171
If disambiguation needed:
172172
{
173-
\"disambiguation\": [{
174-
\"question\": \"<specific_question>\",
175-
\"matching_columns\": [\"<column1>\", \"<column2>\"],
176-
\"matching_filter_values\": [\"<value1>\", \"<value2>\"],
177-
\"other_user_choices\": [\"<choice1>\", \"<choice2>\"]
178-
}]
173+
\"disambiguation_requests\": [
174+
{
175+
\"agent_question\": \"<specific_question>\",
176+
\"user_choices\": [\"<choice1>\", \"<choice2>\"]
177+
},
178+
{
179+
\"agent_question\": \"<specific_question>\",
180+
\"user_choices\": [\"<choice1>\", \"<choice2>\"]
181+
}
182+
]
179183
}
180184
TERMINATE
181185
</output_format>

0 commit comments

Comments
 (0)