Skip to content

Commit c0c630c

Browse files
committed
Update query validation
1 parent e846e9e commit c0c630c

File tree

5 files changed

+61
-9
lines changed

5 files changed

+61
-9
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,18 @@
1616
)
1717
from autogen_agentchat.agents import UserProxyAgent
1818
from autogen_agentchat.messages import TextMessage
19-
from autogen_agentchat.base import Response
2019
import json
2120
import os
2221
from datetime import datetime
2322

23+
from text_2_sql_core.payloads import (
24+
AnswerWithSources,
25+
UserInformationRequest,
26+
ProcessingUpdate,
27+
)
28+
from autogen_agentchat.base import Response, TaskResult
29+
from asyncio import AsyncGenerator
30+
2431

2532
class EmptyResponseUserProxyAgent(UserProxyAgent):
2633
"""UserProxyAgent that automatically responds with empty messages."""
@@ -120,12 +127,12 @@ def agentic_flow(self):
120127
)
121128
return flow
122129

123-
def process_question(
130+
async def process_question(
124131
self,
125132
question: str,
126133
chat_history: list[str] = None,
127134
parameters: dict = None,
128-
):
135+
) -> AsyncGenerator[AnswerWithSources | UserInformationRequest]:
129136
"""Process the complete question through the unified system.
130137
131138
Args:
@@ -152,4 +159,42 @@ def process_question(
152159
for idx, chat in enumerate(chat_history):
153160
agent_input[f"chat_{idx}"] = chat
154161

155-
return self.agentic_flow.run_stream(task=json.dumps(agent_input))
162+
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
163+
logging.info("Message: %s", message)
164+
logging.info("Message type: %s", type(message))
165+
166+
payload = None
167+
168+
if isinstance(message, TextMessage):
169+
if message.source == "query_rewrite_agent":
170+
# If the message is from the query_rewrite_agent, we need to update the chat history
171+
payload = ProcessingUpdate(
172+
title="Rewriting the query...",
173+
)
174+
elif message.source == "parallel_query_solving_agent":
175+
# If the message is from the parallel_query_solving_agent, we need to update the chat history
176+
payload = ProcessingUpdate(
177+
title="Solving the query...",
178+
)
179+
elif message.source == "answer_agent":
180+
# If the message is from the answer_agent, we need to update the chat history
181+
payload = ProcessingUpdate(
182+
title="Generating the answer...",
183+
)
184+
185+
elif isinstance(message, TaskResult):
186+
# Now we need to return the final answer or the disambiguation request
187+
188+
if message.task == "answer_agent":
189+
# If the message is from the answer_agent, we need to return the final answer
190+
payload = AnswerWithSources(
191+
**json.loads(message.content),
192+
)
193+
else:
194+
payload = UserInformationRequest(
195+
**json.loads(message.content),
196+
)
197+
198+
if payload is not None:
199+
logging.info("Payload: %s", payload)
200+
yield payload

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def query_validation(
143143
sqlglot.transpile(
144144
sql_query,
145145
read=self.database_engine.value.lower(),
146-
error_level=sqlglot.ErrorLevel.ERROR,
146+
error_level=sqlglot.ErrorLevel.RAISE,
147147
)
148148
except sqlglot.errors.ParseError as e:
149149
logging.error("SQL Query is invalid: %s", e.errors)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source
22
from text_2_sql_core.payloads.user_information_request import UserInformationRequest
3+
from text_2_sql_core.payloads.processing_update import ProcessingUpdate
34

4-
__all__ = ["AnswerWithSources", "Source", "UserInformationRequest"]
5+
__all__ = ["AnswerWithSources", "Source", "UserInformationRequest", "ProcessingUpdate"]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel
2+
3+
4+
class ProcessingUpdate(BaseModel):
5+
title: str
6+
message: str

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@ class RequestType(StrEnum):
99

1010

1111
class ClarificationRequest(BaseModel):
12-
type: Literal[RequestType.CLARIFICATION]
12+
request_type: Literal[RequestType.CLARIFICATION]
1313
question: str
1414
other_user_choices: list[str]
1515

1616

1717
class DismabiguationRequest(BaseModel):
18-
type: Literal[RequestType.DISAMBIGUATION]
18+
request_type: Literal[RequestType.DISAMBIGUATION]
1919
question: str
2020
matching_columns: list[str]
2121
matching_filter_values: list[str]
2222
other_user_choices: list[str]
2323

2424

2525
class UserInformationRequest(RootModel):
26-
root: DismabiguationRequest = Field(..., discriminator="type")
26+
root: DismabiguationRequest = Field(..., discriminator="request_type")

0 commit comments

Comments
 (0)