1515from ms_agent .tools import ToolManager
1616from ms_agent .utils import async_retry
1717from ms_agent .utils .logger import logger
18- from omegaconf import DictConfig
18+ from omegaconf import DictConfig , OmegaConf
1919
2020from ..utils .utils import read_history , save_history
2121from .base import Agent
@@ -308,8 +308,9 @@ def _log_output(content: str, tag: str):
308308 for _line in line .split ('\\ n' ):
309309 logger .info (f'[{ tag } ] { _line } ' )
310310
311- @async_retry (max_attempts = 2 )
312- async def _step (self , messages : List [Message ], tag : str ) -> List [Message ]:
311+ @async_retry (max_attempts = 2 , delay = 1.0 )
312+ async def _step (self , messages : List [Message ],
313+ tag : str ) -> List [Message ]: # type: ignore
313314 """
314315 Execute a single step in the agent's interaction loop.
315316
@@ -345,12 +346,18 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
345346 self .config .generation_config , 'stream' , False ):
346347 self ._log_output ('[assistant]:' , tag = tag )
347348 _content = ''
349+ is_first = True
348350 for _response_message in self ._handle_stream_message (
349351 messages , tools = tools ):
352+ if is_first :
353+ messages .append (_response_message )
354+ is_first = False
350355 new_content = _response_message .content [len (_content ):]
351356 sys .stdout .write (new_content )
352357 sys .stdout .flush ()
353358 _content = _response_message .content
359+ messages [- 1 ] = _response_message
360+ yield messages
354361 sys .stdout .write ('\n ' )
355362 else :
356363 _response_message = self .llm .generate (messages , tools = tools )
@@ -381,7 +388,7 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
381388 f'[usage] prompt_tokens: { _response_message .prompt_tokens } , '
382389 f'completion_tokens: { _response_message .completion_tokens } ' ,
383390 tag = tag )
384- return messages
391+ yield messages
385392
386393 def _prepare_llm (self ):
387394 """Initialize the LLM model from the configuration."""
@@ -440,13 +447,8 @@ def _save_history(self, messages: List[Message], **kwargs):
440447 config = config ,
441448 messages = messages )
442449
443- async def run (self , messages : Union [List [Message ], str ],
444- ** kwargs ) -> List [Message ]:
445- """
446- Main method to execute the agent.
447-
448- Runs the agent loop, which includes generating responses,
449- calling tools, and managing memory and planning.
450+ async def _run (self , messages : Union [List [Message ], str ], ** kwargs ):
451+ """Run the agent, mainly contains a llm calling and tool calling loop.
450452
451453 Args:
452454 messages (Union[List[Message], str]): Input data for the agent. Can be a raw string prompt,
@@ -483,7 +485,9 @@ async def run(self, messages: Union[List[Message], str],
483485 self ._log_output ('[' + message .role + ']:' , tag = self .tag )
484486 self ._log_output (message .content , tag = self .tag )
485487 while not self .runtime .should_stop :
486- messages = await self ._step (messages , self .tag )
488+ yield_step = self ._step (messages , self .tag )
489+ async for messages in yield_step :
490+ yield messages
487491 self .runtime .round += 1
488492 # +1 means the next round the assistant may give a conclusion
489493 if self .runtime .round >= self .max_chat_round + 1 :
@@ -495,15 +499,35 @@ async def run(self, messages: Union[List[Message], str],
495499 f'Task { messages [1 ].content } failed, max round({ self .max_chat_round } ) exceeded.'
496500 ))
497501 self .runtime .should_stop = True
502+ yield messages
498503 # save history
499504 self ._save_history (messages , ** kwargs )
500505
501506 await self ._loop_callback ('on_task_end' , messages )
502507 await self ._cleanup_tools ()
503- return messages
504508 except Exception as e :
505509 if hasattr (self .config , 'help' ):
506510 logger .error (
507511 f'[{ self .tag } ] Runtime error, please follow the instructions:\n \n { self .config .help } '
508512 )
509513 raise e
514+
async def run(self, messages: Union[List[Message], str],
              **kwargs) -> List[Message]:
    """
    Main method to execute the agent.

    Thin public entry point around ``self._run``: it either exposes the
    intermediate results of ``self._run`` as a stream, or consumes them
    and returns only the final one.

    Args:
        messages (Union[List[Message], str]): Input for the agent; passed
            through unchanged to ``self._run``.
        **kwargs: Forwarded unchanged to ``self._run``. A truthy ``stream``
            entry selects streaming mode.

    Returns:
        When ``stream`` is falsy or absent: the last value yielded by
        ``self._run`` (``None`` if it yields nothing).
        When ``stream`` is truthy: the async generator created by
        ``self._run``, to be consumed with ``async for``. NOTE: the return
        annotation only describes the non-streaming case.
    """
    if kwargs.get('stream', False):
        # Streaming output requires the generation config to stream too,
        # so switch `generation_config.stream` on in the agent config.
        OmegaConf.update(
            self.config, 'generation_config.stream', True, merge=True)
        # Return the generator itself instead of re-yielding it through a
        # wrapper: the chunks are identical, and closing the stream then
        # reaches `_run` directly rather than stopping at the wrapper.
        return self._run(messages=messages, **kwargs)

    # Non-streaming: drain every intermediate state, keep only the last.
    final_messages = None
    async for chunk in self._run(messages=messages, **kwargs):
        final_messages = chunk
    return final_messages
0 commit comments