1515from ms_agent .tools import ToolManager
1616from ms_agent .utils import async_retry
1717from ms_agent .utils .logger import logger
18- from omegaconf import DictConfig
18+ from omegaconf import DictConfig , OmegaConf
1919
2020from ..utils .utils import read_history , save_history
2121from .base import Agent
2424from .plan .utils import planer_mapping
2525from .runtime import Runtime
2626
27- DEFAULT_YAML = os .path .join (
28- os .path .dirname (os .path .abspath (__file__ )), 'agent.yaml' )
29-
3027
3128class LLMAgent (Agent ):
3229 """
@@ -51,7 +48,7 @@ class LLMAgent(Agent):
5148 DEFAULT_SYSTEM = 'You are a helpful assistant.'
5249
5350 def __init__ (self ,
54- config_dir_or_id : Optional [str ] = DEFAULT_YAML ,
51+ config_dir_or_id : Optional [str ] = None ,
5552 config : Optional [DictConfig ] = None ,
5653 env : Optional [Dict [str , str ]] = None ,
5754 ** kwargs ):
@@ -311,8 +308,9 @@ def _log_output(content: str, tag: str):
311308 for _line in line .split ('\\ n' ):
312309 logger .info (f'[{ tag } ] { _line } ' )
313310
314- @async_retry (max_attempts = 2 )
315- async def _step (self , messages : List [Message ], tag : str ) -> List [Message ]:
311+ @async_retry (max_attempts = 2 , delay = 1.0 )
312+ async def _step (self , messages : List [Message ],
313+ tag : str ) -> List [Message ]: # type: ignore
316314 """
317315 Execute a single step in the agent's interaction loop.
318316
@@ -348,12 +346,18 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
348346 self .config .generation_config , 'stream' , False ):
349347 self ._log_output ('[assistant]:' , tag = tag )
350348 _content = ''
349+ is_first = True
351350 for _response_message in self ._handle_stream_message (
352351 messages , tools = tools ):
352+ if is_first :
353+ messages .append (_response_message )
354+ is_first = False
353355 new_content = _response_message .content [len (_content ):]
354356 sys .stdout .write (new_content )
355357 sys .stdout .flush ()
356358 _content = _response_message .content
359+ messages [- 1 ] = _response_message
360+ yield messages
357361 sys .stdout .write ('\n ' )
358362 else :
359363 _response_message = self .llm .generate (messages , tools = tools )
@@ -384,7 +388,7 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
384388 f'[usage] prompt_tokens: { _response_message .prompt_tokens } , '
385389 f'completion_tokens: { _response_message .completion_tokens } ' ,
386390 tag = tag )
387- return messages
391+ yield messages
388392
389393 def _prepare_llm (self ):
390394 """Initialize the LLM model from the configuration."""
@@ -443,13 +447,8 @@ def _save_history(self, messages: List[Message], **kwargs):
443447 config = config ,
444448 messages = messages )
445449
446- async def run (self , messages : Union [List [Message ], str ],
447- ** kwargs ) -> List [Message ]:
448- """
449- Main method to execute the agent.
450-
451- Runs the agent loop, which includes generating responses,
452- calling tools, and managing memory and planning.
450+ async def _run (self , messages : Union [List [Message ], str ], ** kwargs ):
451+ """Run the agent, mainly contains a llm calling and tool calling loop.
453452
454453 Args:
455454 messages (Union[List[Message], str]): Input data for the agent. Can be a raw string prompt,
@@ -486,7 +485,9 @@ async def run(self, messages: Union[List[Message], str],
486485 self ._log_output ('[' + message .role + ']:' , tag = self .tag )
487486 self ._log_output (message .content , tag = self .tag )
488487 while not self .runtime .should_stop :
489- messages = await self ._step (messages , self .tag )
488+ yield_step = self ._step (messages , self .tag )
489+ async for messages in yield_step :
490+ yield messages
490491 self .runtime .round += 1
491492 # +1 means the next round the assistant may give a conclusion
492493 if self .runtime .round >= self .max_chat_round + 1 :
@@ -498,15 +499,35 @@ async def run(self, messages: Union[List[Message], str],
498499 f'Task { messages [1 ].content } failed, max round({ self .max_chat_round } ) exceeded.'
499500 ))
500501 self .runtime .should_stop = True
502+ yield messages
501503 # save history
502504 self ._save_history (messages , ** kwargs )
503505
504506 await self ._loop_callback ('on_task_end' , messages )
505507 await self ._cleanup_tools ()
506- return messages
507508 except Exception as e :
508509 if hasattr (self .config , 'help' ):
509510 logger .error (
510511 f'[{ self .tag } ] Runtime error, please follow the instructions:\n \n { self .config .help } '
511512 )
512513 raise e
514+
async def run(self, messages: Union[List[Message], str],
              **kwargs) -> List[Message]:
    """Execute the agent, either streaming or returning the final result.

    This is a thin front-end over ``_run``, which is consumed here as an
    async generator of message lists.

    Args:
        messages (Union[List[Message], str]): Input forwarded unchanged to
            ``_run``.
        **kwargs: Forwarded unchanged to ``_run``. The optional ``stream``
            flag (default ``False``) selects the return mode.

    Returns:
        When ``stream`` is falsy: the last value yielded by ``_run``, or
        ``None`` if ``_run`` yields nothing.
        When ``stream`` is truthy: the async generator produced by
        ``_run``; the caller iterates it with ``async for`` after awaiting
        this coroutine. NOTE: the return annotation only describes the
        non-streaming case.
    """
    stream = kwargs.get('stream', False)
    if stream:
        # Switch the generation config to streaming before `_run` starts.
        OmegaConf.update(
            self.config, 'generation_config.stream', True, merge=True)
        # Return the generator itself rather than a pass-through wrapper:
        # closing a wrapper early would leave the inner `_run` generator
        # suspended until garbage collection instead of closing it.
        return self._run(messages=messages, **kwargs)

    # Non-streaming: drain every intermediate state, keep only the last.
    final_messages = None
    async for chunk in self._run(messages=messages, **kwargs):
        final_messages = chunk
    return final_messages
0 commit comments