55from autogen_agentchat .agents import BaseChatAgent
66from autogen_agentchat .base import Response
77from autogen_agentchat .messages import (
8- AgentEvent ,
98 AgentMessage ,
109 ChatMessage ,
1110 TextMessage ,
1211)
1312from autogen_core import CancellationToken
1413import json
1514import logging
16- import asyncio
1715from autogen_text_2_sql .inner_autogen_text_2_sql import InnerAutoGenText2Sql
1816
17+ from aiostream import stream
18+
1919
2020class ParallelQuerySolvingAgent (BaseChatAgent ):
2121 def __init__ (self , engine_specific_rules : str , ** kwargs : dict ):
@@ -45,11 +45,10 @@ async def on_messages(
4545 async def on_messages_stream (
4646 self , messages : Sequence [ChatMessage ], cancellation_token : CancellationToken
4747 ) -> AsyncGenerator [AgentMessage | Response , None ]:
48- inner_messages : List [AgentEvent | ChatMessage ] = []
48+ inner_messages : List [AgentMessage | ChatMessage ] = []
4949
5050 last_response = messages [- 1 ].content
5151 parameter_input = messages [0 ].content
52- last_response = messages [- 1 ].content
5352 try :
5453 user_parameters = json .loads (parameter_input )["parameters" ]
5554 except json .JSONDecodeError :
@@ -61,31 +60,71 @@ async def on_messages_stream(
6160
6261 logging .info (f"Query Rewrite: { query_rewrites } " )
6362
64- inner_solving_tasks = []
63+ inner_solving_generators = []
64+
65+ async def consume_inner_messages_from_agentic_flow (
66+ agentic_flow , identifier , complete_inner_messages
67+ ):
68+ """
69+ Consume the inner messages and append them to the specified list.
70+
71+ Args:
72+ ----
73+ agentic_flow: The async generator to consume messages from.
74+ messages_list: The list to which messages should be added.
75+ """
76+ async for inner_message in agentic_flow :
77+ # Add message to results dictionary, tagged by the function name
78+ if identifier not in complete_inner_messages :
79+ complete_inner_messages [identifier ] = []
80+ complete_inner_messages [identifier ].append (inner_message )
81+
82+ yield {"source" : identifier , "message" : inner_message }
6583
84+ complete_inner_messages = {}
85+
86+ # Start processing sub-queries
6687 for query_rewrite in query_rewrites ["sub_queries" ]:
6788 # Create an instance of the InnerAutoGenText2Sql class
6889 inner_autogen_text_2_sql = InnerAutoGenText2Sql (
6990 self .engine_specific_rules , ** self .kwargs
7091 )
7192
72- inner_solving_tasks .append (
93+ # Launch tasks for each sub-query
94+ inner_solving_generators .append (
7395 inner_autogen_text_2_sql .process_question (
7496 question = query_rewrite , parameters = user_parameters
7597 )
7698 )
7799
78- # Wait for all the inner solving tasks to complete
79- inner_solving_results = await asyncio .gather (* inner_solving_tasks )
100+ combined_message_streams = stream .merge (* inner_solving_generators )
101+
102+ async with combined_message_streams .stream () as streamer :
103+ async for inner_message in streamer :
104+ print (inner_message )
105+ yield inner_message
106+
107+ # # Process the results as they are yielded
108+ # for completed in asyncio.as_completed(inner_solving_generators):
109+ # async for inner_message in completed:
110+ # # Yield the result as soon as it's available
111+ # yield inner_message
112+
113+ # # Wait for all tasks to complete
114+ # await asyncio.gather(*inner_solving_generators, return_exceptions=True)
115+
116+ # # Log final results for debugging or auditing
117+ # logging.info(f"Formatted Results: {complete_inner_messages}")
80118
81- logging . info ( f"Inner Solving Results: { inner_solving_results } " )
119+ # TODO: Trim out unnecessary information from the final response
82120
121+ # Final response
83122 yield Response (
84123 chat_message = TextMessage (
85- content = json .dumps (inner_solving_results ), source = self .name
124+ content = json .dumps (complete_inner_messages ), source = self .name
86125 ),
87- inner_messages = inner_messages ,
126+ inner_messages = complete_inner_messages ,
88127 )
89128
90- async def on_reset (self , cancellation_token : CancellationToken ) -> None :
91- pass
129+ async def on_reset (self , cancellation_token : CancellationToken ) -> None :
130+ pass
0 commit comments