Skip to content

Commit f46d9bd

Browse files
committed
Update payload
1 parent 00b332d commit f46d9bd

File tree

7 files changed

+114
-131
lines changed

7 files changed

+114
-131
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ def get_all_agents(self):
4040
# Get current datetime for the Query Rewrite Agent
4141
current_datetime = datetime.now()
4242

43-
self.user_input_rewrite_agent = LLMAgentCreator.create(
44-
"user_input_rewrite_agent", current_datetime=current_datetime
43+
self.message_rewrite_agent = LLMAgentCreator.create(
44+
"message_rewrite_agent", current_datetime=current_datetime
4545
)
4646

4747
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
4848

4949
self.answer_agent = LLMAgentCreator.create("answer_agent")
5050

5151
agents = [
52-
self.user_input_rewrite_agent,
52+
self.message_rewrite_agent,
5353
self.parallel_query_solving_agent,
5454
self.answer_agent,
5555
]
@@ -73,11 +73,11 @@ def unified_selector(self, messages):
7373
current_agent = messages[-1].source if messages else "user"
7474
decision = None
7575

76-
# If this is the first message start with user_input_rewrite_agent
76+
# If this is the first message start with message_rewrite_agent
7777
if current_agent == "user":
78-
decision = "user_input_rewrite_agent"
78+
decision = "message_rewrite_agent"
7979
# Handle transition after query rewriting
80-
elif current_agent == "user_input_rewrite_agent":
80+
elif current_agent == "message_rewrite_agent":
8181
decision = "parallel_query_solving_agent"
8282
# Handle transition after parallel query solving
8383
elif current_agent == "parallel_query_solving_agent":
@@ -145,24 +145,24 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145145
except json.JSONDecodeError:
146146
logging.warning("Unable to read SQL query results: %s", sql_query_results)
147147
sql_query_results = {}
148-
sub_user_input_results = {}
148+
sub_message_results = {}
149149
else:
150-
# Only load sub-user_input results if we have a database result
151-
sub_user_input_results = self.parse_message_content(messages[1].content)
152-
logging.info("Sub-user_input Results: %s", sub_user_input_results)
150+
# Only load sub-message results if we have a database result
151+
sub_message_results = self.parse_message_content(messages[1].content)
152+
logging.info("Sub-message Results: %s", sub_message_results)
153153

154154
try:
155-
sub_user_inputs = [
156-
sub_user_input
157-
for sub_user_input_group in sub_user_input_results.get(
158-
"sub_user_inputs", []
155+
decomposed_messages = [
156+
sub_message
157+
for sub_message_group in sub_message_results.get(
158+
"decomposed_messages", []
159159
)
160-
for sub_user_input in sub_user_input_group
160+
for sub_message in sub_message_group
161161
]
162162

163163
logging.info("SQL Query Results: %s", sql_query_results)
164164
payload = AnswerWithSourcesPayload(
165-
answer=answer, sub_user_inputs=sub_user_inputs
165+
answer=answer, decomposed_messages=decomposed_messages
166166
)
167167

168168
if not isinstance(sql_query_results, dict):
@@ -173,11 +173,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
173173
logging.error("No 'results' key in sql_query_results")
174174
return payload
175175

176-
for user_input, sql_query_result_list in sql_query_results[
177-
"results"
178-
].items():
176+
for message, sql_query_result_list in sql_query_results["results"].items():
179177
if not sql_query_result_list: # Check if list is empty
180-
logging.warning(f"No results for user_input: {user_input}")
178+
logging.warning(f"No results for message: {message}")
181179
continue
182180

183181
for sql_query_result in sql_query_result_list:
@@ -213,47 +211,47 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
213211
answer=f"{answer}\nError processing results: {str(e)}"
214212
)
215213

216-
async def process_user_input(
214+
async def process_message(
217215
self,
218-
user_input_payload: UserInputPayload,
216+
message_payload: UserInputPayload,
219217
chat_history: list[InteractionPayload] = None,
220218
) -> AsyncGenerator[InteractionPayload, None]:
221-
"""Process the complete user_input through the unified system.
219+
"""Process the complete message through the unified system.
222220
223221
Args:
224222
----
225-
task (str): The user user_input to process.
223+
task (str): The user message to process.
226224
chat_history (list[str], optional): The chat history. Defaults to None.
227225
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
228226
229227
Returns:
230228
-------
231229
dict: The response from the system.
232230
"""
233-
logging.info("Processing user_input: %s", user_input_payload.body.user_input)
231+
logging.info("Processing message: %s", message_payload.body.message)
234232
logging.info("Chat history: %s", chat_history)
235233

236234
agent_input = {
237-
"user_input": user_input_payload.body.user_input,
235+
"message": message_payload.body.message,
238236
"chat_history": {},
239-
"injected_parameters": user_input_payload.body.injected_parameters,
237+
"injected_parameters": message_payload.body.injected_parameters,
240238
}
241239

242240
if chat_history is not None:
243241
# Update input
244242
for idx, chat in enumerate(chat_history):
245-
if chat.root.payload_type == PayloadType.USER_INPUT:
243+
if chat.root.payload_type == PayloadType.message:
246244
# For now only consider the user query
247245
chat_history_key = f"chat_{idx}"
248-
agent_input[chat_history_key] = chat.root.body.user_input
246+
agent_input[chat_history_key] = chat.root.body.message
249247

250248
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
251249
logging.debug("Message: %s", message)
252250

253251
payload = None
254252

255253
if isinstance(message, TextMessage):
256-
if message.source == "user_input_rewrite_agent":
254+
if message.source == "message_rewrite_agent":
257255
payload = ProcessingUpdatePayload(
258256
message="Rewriting the query...",
259257
)
@@ -276,10 +274,10 @@ async def process_user_input(
276274
elif message.messages[-1].source == "parallel_query_solving_agent":
277275
# Load into disambiguation request
278276
payload = self.extract_disambiguation_request(message.messages)
279-
elif message.messages[-1].source == "user_input_rewrite_agent":
277+
elif message.messages[-1].source == "message_rewrite_agent":
280278
# Load into empty response
281279
payload = AnswerWithSourcesPayload(
282-
answer="Apologies, I cannot answer that user_input as it is not relevant. Please try another user_input or rephrase your current user_input."
280+
answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message."
283281
)
284282
else:
285283
logging.error("Unexpected TaskResult: %s", message)

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ async def on_messages_stream(
8484
injected_parameters = {}
8585

8686
# Load the json of the last message to populate the final output object
87-
user_input_rewrites = json.loads(last_response)
87+
message_rewrites = json.loads(last_response)
8888

89-
logging.info(f"Query Rewrites: {user_input_rewrites}")
89+
logging.info(f"Query Rewrites: {message_rewrites}")
9090

9191
async def consume_inner_messages_from_agentic_flow(
9292
agentic_flow, identifier, database_results
@@ -143,7 +143,7 @@ async def consume_inner_messages_from_agentic_flow(
143143
):
144144
logging.info("Contains pre-run results")
145145
for pre_run_sql_query, pre_run_result in parsed_message[
146-
"cached_user_inputs_and_schemas"
146+
"cached_messages_and_schemas"
147147
].items():
148148
database_results[identifier].append(
149149
{
@@ -164,7 +164,7 @@ async def consume_inner_messages_from_agentic_flow(
164164

165165
# Convert all_non_database_query to lowercase string and compare
166166
all_non_database_query = str(
167-
user_input_rewrites.get("all_non_database_query", "false")
167+
message_rewrites.get("all_non_database_query", "false")
168168
).lower()
169169

170170
if all_non_database_query == "true":
@@ -177,12 +177,12 @@ async def consume_inner_messages_from_agentic_flow(
177177
return
178178

179179
# Start processing sub-queries
180-
for user_input_rewrite in user_input_rewrites["sub_user_inputs"]:
181-
logging.info(f"Processing sub-query: {user_input_rewrite}")
180+
for message_rewrite in message_rewrites["decomposed_messages"]:
181+
logging.info(f"Processing sub-query: {message_rewrite}")
182182
# Create an instance of the InnerAutoGenText2Sql class
183183
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
184184

185-
identifier = ", ".join(user_input_rewrite)
185+
identifier = ", ".join(message_rewrite)
186186

187187
# Add database connection info to injected parameters
188188
query_params = injected_parameters.copy() if injected_parameters else {}
@@ -196,8 +196,8 @@ async def consume_inner_messages_from_agentic_flow(
196196
# Launch tasks for each sub-query
197197
inner_solving_generators.append(
198198
consume_inner_messages_from_agentic_flow(
199-
inner_autogen_text_2_sql.process_user_input(
200-
user_input=user_input_rewrite,
199+
inner_autogen_text_2_sql.process_message(
200+
message=message_rewrite,
201201
injected_parameters=query_params,
202202
),
203203
identifier,

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SqlQueryCacheAgent(BaseChatAgent):
1717
def __init__(self):
1818
super().__init__(
1919
"sql_query_cache_agent",
20-
"An agent that fetches the queries from the cache based on the user user_input.",
20+
"An agent that fetches the queries from the cache based on the user message.",
2121
)
2222

2323
self.agent = SqlQueryCacheAgentCustomAgent()
@@ -40,19 +40,19 @@ async def on_messages(
4040
async def on_messages_stream(
4141
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4242
) -> AsyncGenerator[AgentMessage | Response, None]:
43-
# Get the decomposed user_inputs from the user_input_rewrite_agent
43+
# Get the decomposed messages from the message_rewrite_agent
4444
try:
4545
request_details = json.loads(messages[0].content)
4646
injected_parameters = request_details["injected_parameters"]
47-
user_user_inputs = request_details["user_input"]
48-
logging.info(f"Processing user_inputs: {user_user_inputs}")
47+
user_messages = request_details["message"]
48+
logging.info(f"Processing messages: {user_messages}")
4949
logging.info(f"Input Parameters: {injected_parameters}")
5050
except json.JSONDecodeError:
51-
# If not JSON array, process as single user_input
51+
# If not JSON array, process as single message
5252
raise ValueError("Could not load message")
5353

5454
cached_results = await self.agent.process_message(
55-
user_user_inputs, injected_parameters
55+
user_messages, injected_parameters
5656
)
5757
yield Response(
5858
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,19 @@ async def on_messages_stream(
4343
# Try to parse as JSON first
4444
try:
4545
request_details = json.loads(messages[0].content)
46-
user_inputs = request_details["question"]
46+
messages = request_details["question"]
4747
except (json.JSONDecodeError, KeyError):
4848
# If not JSON or missing question key, use content directly
49-
user_inputs = messages[0].content
49+
messages = messages[0].content
5050

51-
if isinstance(user_inputs, str):
52-
user_inputs = [user_inputs]
53-
elif not isinstance(user_inputs, list):
54-
user_inputs = [str(user_inputs)]
51+
if isinstance(messages, str):
52+
messages = [messages]
53+
elif not isinstance(messages, list):
54+
messages = [str(messages)]
5555

56-
logging.info(f"Processing questions: {user_inputs}")
56+
logging.info(f"Processing questions: {messages}")
5757

58-
final_results = await self.agent.process_message(user_inputs)
58+
final_results = await self.agent.process_message(messages)
5959

6060
yield Response(
6161
chat_message=TextMessage(

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self):
99
self.sql_connector = ConnectorFactory.get_database_connector()
1010

1111
async def process_message(
12-
self, user_inputs: list[str], injected_parameters: dict
12+
self, messages: list[str], injected_parameters: dict
1313
) -> dict:
1414
# Initialize results dictionary
1515
cached_results = {
@@ -18,7 +18,7 @@ async def process_message(
1818
}
1919

2020
# Process each question sequentially
21-
for question in user_inputs:
21+
for question in messages:
2222
# Fetch the queries from the cache based on the question
2323
logging.info(f"Fetching queries from cache for question: {question}")
2424
cached_query = await self.sql_connector.fetch_queries_from_cache(

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def __init__(self, **kwargs):
2222

2323
self.system_prompt = Template(system_prompt).render(kwargs)
2424

25-
async def process_message(self, user_inputs: list[str]) -> dict:
26-
logging.info(f"user inputs: {user_inputs}")
25+
async def process_message(self, messages: list[str]) -> dict:
26+
logging.info(f"user inputs: {messages}")
2727

2828
entity_tasks = []
2929

30-
for user_input in user_inputs:
30+
for message in messages:
3131
messages = [
3232
{"role": "system", "content": self.system_prompt},
33-
{"role": "user", "content": user_input},
33+
{"role": "user", "content": message},
3434
]
3535
entity_tasks.append(
3636
self.open_ai_connector.run_completion_request(

0 commit comments

Comments
 (0)