Skip to content

Commit 222c13b

Browse files
committed
Update user input
1 parent 7211486 commit 222c13b

File tree

9 files changed

+67
-63
lines changed

9 files changed

+67
-63
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ def get_all_agents(self):
4040
# Get current datetime for the Query Rewrite Agent
4141
current_datetime = datetime.now()
4242

43-
self.question_rewrite_agent = LLMAgentCreator.create(
44-
"question_rewrite_agent", current_datetime=current_datetime
43+
self.user_input_rewrite_agent = LLMAgentCreator.create(
44+
"user_input_rewrite_agent", current_datetime=current_datetime
4545
)
4646

4747
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
4848

4949
self.answer_agent = LLMAgentCreator.create("answer_agent")
5050

5151
agents = [
52-
self.question_rewrite_agent,
52+
self.user_input_rewrite_agent,
5353
self.parallel_query_solving_agent,
5454
self.answer_agent,
5555
]
@@ -73,11 +73,11 @@ def unified_selector(self, messages):
7373
current_agent = messages[-1].source if messages else "user"
7474
decision = None
7575

76-
# If this is the first message start with question_rewrite_agent
76+
# If this is the first message start with user_input_rewrite_agent
7777
if current_agent == "user":
78-
decision = "question_rewrite_agent"
78+
decision = "user_input_rewrite_agent"
7979
# Handle transition after query rewriting
80-
elif current_agent == "question_rewrite_agent":
80+
elif current_agent == "user_input_rewrite_agent":
8181
decision = "parallel_query_solving_agent"
8282
# Handle transition after parallel query solving
8383
elif current_agent == "parallel_query_solving_agent":
@@ -145,22 +145,24 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145145
except json.JSONDecodeError:
146146
logging.warning("Unable to read SQL query results: %s", sql_query_results)
147147
sql_query_results = {}
148-
sub_question_results = {}
148+
sub_user_input_results = {}
149149
else:
150-
# Only load sub-question results if we have a database result
151-
sub_question_results = self.parse_message_content(messages[1].content)
152-
logging.info("Sub-Question Results: %s", sub_question_results)
150+
# Only load sub-user_input results if we have a database result
151+
sub_user_input_results = self.parse_message_content(messages[1].content)
152+
logging.info("Sub-user_input Results: %s", sub_user_input_results)
153153

154154
try:
155-
sub_questions = [
156-
sub_question
157-
for sub_question_group in sub_question_results.get("sub_questions", [])
158-
for sub_question in sub_question_group
155+
sub_user_inputs = [
156+
sub_user_input
157+
for sub_user_input_group in sub_user_input_results.get(
158+
"sub_user_inputs", []
159+
)
160+
for sub_user_input in sub_user_input_group
159161
]
160162

161163
logging.info("SQL Query Results: %s", sql_query_results)
162164
payload = AnswerWithSourcesPayload(
163-
answer=answer, sub_questions=sub_questions
165+
answer=answer, sub_user_inputs=sub_user_inputs
164166
)
165167

166168
if not isinstance(sql_query_results, dict):
@@ -171,9 +173,11 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
171173
logging.error("No 'results' key in sql_query_results")
172174
return payload
173175

174-
for question, sql_query_result_list in sql_query_results["results"].items():
176+
for user_input, sql_query_result_list in sql_query_results[
177+
"results"
178+
].items():
175179
if not sql_query_result_list: # Check if list is empty
176-
logging.warning(f"No results for question: {question}")
180+
logging.warning(f"No results for user_input: {user_input}")
177181
continue
178182

179183
for sql_query_result in sql_query_result_list:
@@ -209,30 +213,30 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
209213
answer=f"{answer}\nError processing results: {str(e)}"
210214
)
211215

212-
async def process_question(
216+
async def process_user_input(
213217
self,
214-
question_payload: UserInputPayload,
218+
user_input_payload: UserInputPayload,
215219
chat_history: list[InteractionPayload] = None,
216220
) -> AsyncGenerator[InteractionPayload, None]:
217-
"""Process the complete question through the unified system.
221+
"""Process the complete user_input through the unified system.
218222
219223
Args:
220224
----
221-
task (str): The user question to process.
225+
task (str): The user user_input to process.
222226
chat_history (list[str], optional): The chat history. Defaults to None.
223227
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
224228
225229
Returns:
226230
-------
227231
dict: The response from the system.
228232
"""
229-
logging.info("Processing question: %s", question_payload.body.question)
233+
logging.info("Processing user_input: %s", user_input_payload.body.user_input)
230234
logging.info("Chat history: %s", chat_history)
231235

232236
agent_input = {
233-
"question": question_payload.body.question,
237+
"user_input": user_input_payload.body.user_input,
234238
"chat_history": {},
235-
"injected_parameters": question_payload.body.injected_parameters,
239+
"injected_parameters": user_input_payload.body.injected_parameters,
236240
}
237241

238242
if chat_history is not None:
@@ -241,15 +245,15 @@ async def process_question(
241245
if chat.root.payload_type == PayloadType.USER_INPUT:
242246
# For now only consider the user query
243247
chat_history_key = f"chat_{idx}"
244-
agent_input[chat_history_key] = chat.root.body.question
248+
agent_input[chat_history_key] = chat.root.body.user_input
245249

246250
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
247251
logging.debug("Message: %s", message)
248252

249253
payload = None
250254

251255
if isinstance(message, TextMessage):
252-
if message.source == "question_rewrite_agent":
256+
if message.source == "user_input_rewrite_agent":
253257
payload = ProcessingUpdatePayload(
254258
message="Rewriting the query...",
255259
)
@@ -272,10 +276,10 @@ async def process_question(
272276
elif message.messages[-1].source == "parallel_query_solving_agent":
273277
# Load into disambiguation request
274278
payload = self.extract_disambiguation_request(message.messages)
275-
elif message.messages[-1].source == "question_rewrite_agent":
279+
elif message.messages[-1].source == "user_input_rewrite_agent":
276280
# Load into empty response
277281
payload = AnswerWithSourcesPayload(
278-
answer="Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question."
282+
answer="Apologies, I cannot answer that user_input as it is not relevant. Please try another user_input or rephrase your current user_input."
279283
)
280284
else:
281285
logging.error("Unexpected TaskResult: %s", message)

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_tool(cls, sql_helper, tool_name: str):
4343
elif tool_name == "sql_get_entity_schemas_tool":
4444
return FunctionToolAlias(
4545
sql_helper.get_entity_schemas,
46-
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
46+
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user input and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
4747
)
4848
elif tool_name == "sql_get_column_values_tool":
4949
return FunctionToolAlias(

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ async def on_messages_stream(
8484
injected_parameters = {}
8585

8686
# Load the json of the last message to populate the final output object
87-
question_rewrites = json.loads(last_response)
87+
user_input_rewrites = json.loads(last_response)
8888

89-
logging.info(f"Query Rewrites: {question_rewrites}")
89+
logging.info(f"Query Rewrites: {user_input_rewrites}")
9090

9191
async def consume_inner_messages_from_agentic_flow(
9292
agentic_flow, identifier, database_results
@@ -143,7 +143,7 @@ async def consume_inner_messages_from_agentic_flow(
143143
):
144144
logging.info("Contains pre-run results")
145145
for pre_run_sql_query, pre_run_result in parsed_message[
146-
"cached_questions_and_schemas"
146+
"cached_user_inputs_and_schemas"
147147
].items():
148148
database_results[identifier].append(
149149
{
@@ -164,7 +164,7 @@ async def consume_inner_messages_from_agentic_flow(
164164

165165
# Convert all_non_database_query to lowercase string and compare
166166
all_non_database_query = str(
167-
question_rewrites.get("all_non_database_query", "false")
167+
user_input_rewrites.get("all_non_database_query", "false")
168168
).lower()
169169

170170
if all_non_database_query == "true":
@@ -177,12 +177,12 @@ async def consume_inner_messages_from_agentic_flow(
177177
return
178178

179179
# Start processing sub-queries
180-
for question_rewrite in question_rewrites["sub_questions"]:
181-
logging.info(f"Processing sub-query: {question_rewrite}")
180+
for user_input_rewrite in user_input_rewrites["sub_user_inputs"]:
181+
logging.info(f"Processing sub-query: {user_input_rewrite}")
182182
# Create an instance of the InnerAutoGenText2Sql class
183183
inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
184184

185-
identifier = ", ".join(question_rewrite)
185+
identifier = ", ".join(user_input_rewrite)
186186

187187
# Add database connection info to injected parameters
188188
query_params = injected_parameters.copy() if injected_parameters else {}
@@ -196,8 +196,8 @@ async def consume_inner_messages_from_agentic_flow(
196196
# Launch tasks for each sub-query
197197
inner_solving_generators.append(
198198
consume_inner_messages_from_agentic_flow(
199-
inner_autogen_text_2_sql.process_question(
200-
question=question_rewrite,
199+
inner_autogen_text_2_sql.process_user_input(
200+
user_input=user_input_rewrite,
201201
injected_parameters=query_params,
202202
),
203203
identifier,

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SqlQueryCacheAgent(BaseChatAgent):
1717
def __init__(self):
1818
super().__init__(
1919
"sql_query_cache_agent",
20-
"An agent that fetches the queries from the cache based on the user question.",
20+
"An agent that fetches the queries from the cache based on the user user_input.",
2121
)
2222

2323
self.agent = SqlQueryCacheAgentCustomAgent()
@@ -40,19 +40,19 @@ async def on_messages(
4040
async def on_messages_stream(
4141
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4242
) -> AsyncGenerator[AgentMessage | Response, None]:
43-
# Get the decomposed questions from the question_rewrite_agent
43+
# Get the decomposed user_inputs from the user_input_rewrite_agent
4444
try:
4545
request_details = json.loads(messages[0].content)
4646
injected_parameters = request_details["injected_parameters"]
47-
user_questions = request_details["question"]
48-
logging.info(f"Processing questions: {user_questions}")
47+
user_user_inputs = request_details["user_input"]
48+
logging.info(f"Processing user_inputs: {user_user_inputs}")
4949
logging.info(f"Input Parameters: {injected_parameters}")
5050
except json.JSONDecodeError:
51-
# If not JSON array, process as single question
51+
# If not JSON array, process as single user_input
5252
raise ValueError("Could not load message")
5353

5454
cached_results = await self.agent.process_message(
55-
user_questions, injected_parameters
55+
user_user_inputs, injected_parameters
5656
)
5757
yield Response(
5858
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_schema_selection_agent.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SqlSchemaSelectionAgent(BaseChatAgent):
1717
def __init__(self, **kwargs):
1818
super().__init__(
1919
"sql_schema_selection_agent",
20-
"An agent that fetches the schemas from the cache based on the user question.",
20+
"An agent that fetches the schemas from the cache based on the user input.",
2121
)
2222

2323
self.agent = SqlSchemaSelectionAgentCustomAgent(**kwargs)
@@ -43,19 +43,19 @@ async def on_messages_stream(
4343
# Try to parse as JSON first
4444
try:
4545
request_details = json.loads(messages[0].content)
46-
user_questions = request_details["question"]
46+
user_inputs = request_details["question"]
4747
except (json.JSONDecodeError, KeyError):
4848
# If not JSON or missing question key, use content directly
49-
user_questions = messages[0].content
49+
user_inputs = messages[0].content
5050

51-
if isinstance(user_questions, str):
52-
user_questions = [user_questions]
53-
elif not isinstance(user_questions, list):
54-
user_questions = [str(user_questions)]
51+
if isinstance(user_inputs, str):
52+
user_inputs = [user_inputs]
53+
elif not isinstance(user_inputs, list):
54+
user_inputs = [str(user_inputs)]
5555

56-
logging.info(f"Processing questions: {user_questions}")
56+
logging.info(f"Processing questions: {user_inputs}")
5757

58-
final_results = await self.agent.process_message(user_questions)
58+
final_results = await self.agent.process_message(user_inputs)
5959

6060
yield Response(
6161
chat_message=TextMessage(

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def process_question(
178178
179179
Args:
180180
----
181-
task (str): The user question to process.
181+
task (str): The user input to process.
182182
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
183183
184184
Returns:

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_query_cache_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(self):
99
self.sql_connector = ConnectorFactory.get_database_connector()
1010

1111
async def process_message(
12-
self, user_questions: list[str], injected_parameters: dict
12+
self, user_inputs: list[str], injected_parameters: dict
1313
) -> dict:
1414
# Initialize results dictionary
1515
cached_results = {
@@ -18,7 +18,7 @@ async def process_message(
1818
}
1919

2020
# Process each question sequentially
21-
for question in user_questions:
21+
for question in user_inputs:
2222
# Fetch the queries from the cache based on the question
2323
logging.info(f"Fetching queries from cache for question: {question}")
2424
cached_query = await self.sql_connector.fetch_queries_from_cache(

text_2_sql/text_2_sql_core/src/text_2_sql_core/custom_agents/sql_schema_selection_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def __init__(self, **kwargs):
2222

2323
self.system_prompt = Template(system_prompt).render(kwargs)
2424

25-
async def process_message(self, user_questions: list[str]) -> dict:
26-
logging.info(f"User questions: {user_questions}")
25+
async def process_message(self, user_inputs: list[str]) -> dict:
26+
logging.info(f"user inputs: {user_inputs}")
2727

2828
entity_tasks = []
2929

30-
for user_question in user_questions:
30+
for user_input in user_inputs:
3131
messages = [
3232
{"role": "system", "content": self.system_prompt},
33-
{"role": "user", "content": user_question},
33+
{"role": "user", "content": user_input},
3434
]
3535
entity_tasks.append(
3636
self.open_ai_connector.run_completion_request(
@@ -47,7 +47,7 @@ async def process_message(self, user_questions: list[str]) -> dict:
4747
logging.info(f"Entity result: {entity_result}")
4848

4949
for entity_group in entity_result.entities:
50-
logging.info(f"Searching for schemas for entity group: {entity_group}")
50+
logging.info("Searching for schemas for entity group: %s", entity_group)
5151
entity_search_tasks.append(
5252
self.sql_connector.get_entity_schemas(
5353
" ".join(entity_group), as_json=False
@@ -56,7 +56,7 @@ async def process_message(self, user_questions: list[str]) -> dict:
5656

5757
for filter_condition in entity_result.filter_conditions:
5858
logging.info(
59-
f"Searching for column values for filter: {filter_condition}"
59+
"Searching for column values for filter: %s", filter_condition
6060
)
6161
column_search_tasks.append(
6262
self.sql_connector.get_column_values(

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/question_rewrite_agent.yaml renamed to text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_input_rewrite_agent.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
model: "4o-mini"
2-
description: "An agent that preprocesses user questions by decomposing complex queries into simpler sub-queries that can be processed independently and then combined."
2+
description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-queries that can be processed independently and then combined."
33
system_message: |
44
<role_and_objective>
55
You are a helpful AI Assistant specializing in breaking down complex questions into simpler sub-queries that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-queries and provide clear instructions for combining their results.

0 commit comments

Comments
 (0)