Skip to content

Commit 124c491

Browse files
committed
Separate out sources and answer agent
1 parent 1132cd0 commit 124c491

File tree

5 files changed

+51
-12
lines changed

5 files changed

+51
-12
lines changed

text_2_sql/autogen/src/__init__.py

Whitespace-only changes.

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
1212
ParallelQuerySolvingAgent,
1313
)
14-
from autogen_text_2_sql.custom_agents.answer_and_sources_agent import (
15-
AnswerAndSourcesAgent,
14+
from text_2_sql.autogen.src.autogen_text_2_sql.custom_agents.sources_agent import (
15+
SourcesAgent,
1616
)
1717
from autogen_agentchat.agents import UserProxyAgent
1818
from autogen_agentchat.messages import TextMessage
@@ -57,7 +57,9 @@ def get_all_agents(self):
5757
engine_specific_rules=self.engine_specific_rules, **self.kwargs
5858
)
5959

60-
self.answer_and_sources_agent = AnswerAndSourcesAgent()
60+
self.answer_agent = LLMAgentCreator.create("answer_agent")
61+
62+
self.sources_agent = SourcesAgent()
6163

6264
# Auto-responding UserProxyAgent
6365
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
@@ -66,7 +68,7 @@ def get_all_agents(self):
6668
self.user_proxy,
6769
self.query_rewrite_agent,
6870
self.parallel_query_solving_agent,
69-
self.answer_and_sources_agent,
71+
self.sources_agent,
7072
]
7173

7274
return agents
@@ -95,7 +97,9 @@ def unified_selector(self, messages):
9597
decision = "parallel_query_solving_agent"
9698
# Handle transition after parallel query solving
9799
elif current_agent == "parallel_query_solving_agent":
98-
decision = "answer_and_sources_agent"
100+
decision = "answer_agent"
101+
elif current_agent == "answer_agent":
102+
decision = "sources_agent"
99103

100104
if decision:
101105
logging.info(f"Agent transition: {current_agent} -> {decision}")
@@ -118,7 +122,7 @@ def agentic_flow(self):
118122

119123
async def process_question(
120124
self,
121-
task: str,
125+
question: str,
122126
chat_history: list[str] = None,
123127
parameters: dict = None,
124128
):
@@ -134,11 +138,11 @@ async def process_question(
134138
-------
135139
dict: The response from the system.
136140
"""
137-
logging.info("Processing question: %s", task)
141+
logging.info("Processing question: %s", question)
138142
logging.info("Chat history: %s", chat_history)
139143

140144
agent_input = {
141-
"user_question": task,
145+
"question": question,
142146
"chat_history": {},
143147
"parameters": parameters,
144148
}

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ async def on_messages_stream(
4141
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4242
) -> AsyncGenerator[AgentMessage | Response, None]:
4343
last_response = messages[-1].content
44+
parameter_input = messages[0].content
45+
last_response = messages[-1].content
46+
try:
47+
user_parameters = json.loads(parameter_input)["parameters"]
48+
except json.JSONDecodeError:
49+
logging.error("Error decoding the user parameters.")
50+
user_parameters = {}
4451

4552
# Load the json of the last message to populate the final output object
4653
query_rewrites = json.loads(last_response)
@@ -49,14 +56,16 @@ async def on_messages_stream(
4956

5057
inner_solving_tasks = []
5158

52-
for query_rewrite in query_rewrites:
59+
for query_rewrite in query_rewrites["sub_queries"]:
5360
# Create an instance of the InnerAutoGenText2Sql class
5461
inner_autogen_text_2_sql = InnerAutoGenText2Sql(
5562
self.engine_specific_rules, **self.kwargs
5663
)
5764

5865
inner_solving_tasks.append(
59-
inner_autogen_text_2_sql.run_stream(task=query_rewrite)
66+
inner_autogen_text_2_sql.process_question(
67+
question=query_rewrite, parameters=user_parameters
68+
)
6069
)
6170

6271
# Wait for all the inner solving tasks to complete

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/answer_and_sources_agent.py renamed to text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import pandas as pd
1313

1414

15-
class AnswerAndSourcesAgent(BaseChatAgent):
15+
class SourcesAgent(BaseChatAgent):
1616
def __init__(self):
1717
super().__init__(
18-
"answer_and_sources_agent",
18+
"sources_agent",
1919
"An agent that formats the final answer and sources.",
2020
)
2121

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,29 @@ def agentic_flow(self):
187187
selector_func=self.unified_selector,
188188
)
189189
return flow
190+
191+
async def process_question(
192+
self,
193+
question: str,
194+
parameters: dict = None,
195+
):
196+
"""Process the complete question through the unified system.
197+
198+
Args:
199+
----
200+
task (str): The user question to process.
201+
parameters (dict, optional): Parameters to pass to agents. Defaults to None.
202+
203+
Returns:
204+
-------
205+
dict: The response from the system.
206+
"""
207+
logging.info("Processing question: %s", question)
208+
209+
agent_input = {
210+
"question": question,
211+
"chat_history": {},
212+
"parameters": parameters,
213+
}
214+
215+
return self.agentic_flow.run_stream(task=json.dumps(agent_input))

0 commit comments

Comments
 (0)