Skip to content

Commit 32fd9d9

Browse files
mushenL杨堃luyan
authored
add streamllm_agent (#635)
Co-authored-by: 杨堃 <yk01645326@alibaba-inc.com> Co-authored-by: luyan <luyan@U-V61TJ94D-2208.local>
1 parent d82e1c8 commit 32fd9d9

File tree

4 files changed

+50
-17
lines changed

4 files changed

+50
-17
lines changed

ms_agent/agent/llm_agent.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ms_agent.tools import ToolManager
1616
from ms_agent.utils import async_retry
1717
from ms_agent.utils.logger import logger
18-
from omegaconf import DictConfig
18+
from omegaconf import DictConfig, OmegaConf
1919

2020
from ..utils.utils import read_history, save_history
2121
from .base import Agent
@@ -308,8 +308,9 @@ def _log_output(content: str, tag: str):
308308
for _line in line.split('\n'):
309309
logger.info(f'[{tag}] {_line}')
310310

311-
@async_retry(max_attempts=2)
312-
async def _step(self, messages: List[Message], tag: str) -> List[Message]:
311+
@async_retry(max_attempts=2, delay=1.0)
312+
async def _step(self, messages: List[Message],
313+
tag: str) -> List[Message]: # type: ignore
313314
"""
314315
Execute a single step in the agent's interaction loop.
315316
@@ -345,12 +346,18 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
345346
self.config.generation_config, 'stream', False):
346347
self._log_output('[assistant]:', tag=tag)
347348
_content = ''
349+
is_first = True
348350
for _response_message in self._handle_stream_message(
349351
messages, tools=tools):
352+
if is_first:
353+
messages.append(_response_message)
354+
is_first = False
350355
new_content = _response_message.content[len(_content):]
351356
sys.stdout.write(new_content)
352357
sys.stdout.flush()
353358
_content = _response_message.content
359+
messages[-1] = _response_message
360+
yield messages
354361
sys.stdout.write('\n')
355362
else:
356363
_response_message = self.llm.generate(messages, tools=tools)
@@ -381,7 +388,7 @@ async def _step(self, messages: List[Message], tag: str) -> List[Message]:
381388
f'[usage] prompt_tokens: {_response_message.prompt_tokens}, '
382389
f'completion_tokens: {_response_message.completion_tokens}',
383390
tag=tag)
384-
return messages
391+
yield messages
385392

386393
def _prepare_llm(self):
387394
"""Initialize the LLM model from the configuration."""
@@ -440,13 +447,8 @@ def _save_history(self, messages: List[Message], **kwargs):
440447
config=config,
441448
messages=messages)
442449

443-
async def run(self, messages: Union[List[Message], str],
444-
**kwargs) -> List[Message]:
445-
"""
446-
Main method to execute the agent.
447-
448-
Runs the agent loop, which includes generating responses,
449-
calling tools, and managing memory and planning.
450+
async def _run(self, messages: Union[List[Message], str], **kwargs):
451+
"""Run the agent, mainly contains a llm calling and tool calling loop.
450452
451453
Args:
452454
messages (Union[List[Message], str]): Input data for the agent. Can be a raw string prompt,
@@ -483,7 +485,9 @@ async def run(self, messages: Union[List[Message], str],
483485
self._log_output('[' + message.role + ']:', tag=self.tag)
484486
self._log_output(message.content, tag=self.tag)
485487
while not self.runtime.should_stop:
486-
messages = await self._step(messages, self.tag)
488+
yield_step = self._step(messages, self.tag)
489+
async for messages in yield_step:
490+
yield messages
487491
self.runtime.round += 1
488492
# +1 means the next round the assistant may give a conclusion
489493
if self.runtime.round >= self.max_chat_round + 1:
@@ -495,15 +499,35 @@ async def run(self, messages: Union[List[Message], str],
495499
f'Task {messages[1].content} failed, max round({self.max_chat_round}) exceeded.'
496500
))
497501
self.runtime.should_stop = True
502+
yield messages
498503
# save history
499504
self._save_history(messages, **kwargs)
500505

501506
await self._loop_callback('on_task_end', messages)
502507
await self._cleanup_tools()
503-
return messages
504508
except Exception as e:
505509
if hasattr(self.config, 'help'):
506510
logger.error(
507511
f'[{self.tag}] Runtime error, please follow the instructions:\n\n {self.config.help}'
508512
)
509513
raise e
514+
515+
async def run(self, messages: Union[List[Message], str],
516+
**kwargs) -> List[Message]:
517+
stream = kwargs.get('stream', False)
518+
if stream:
519+
OmegaConf.update(
520+
self.config, 'generation_config.stream', True, merge=True)
521+
522+
if stream:
523+
524+
async def stream_generator():
525+
async for chunk in self._run(messages=messages, **kwargs):
526+
yield chunk
527+
528+
return stream_generator()
529+
else:
530+
res = None
531+
async for chunk in self._run(messages=messages, **kwargs):
532+
res = chunk
533+
return res

ms_agent/llm/openai_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def format_tools(self,
7272
tools = None
7373
return tools
7474

75-
@retry(max_attempts=12, delay=1.0)
75+
@retry(max_attempts=3, delay=1.0)
7676
def generate(self,
7777
messages: List[Message],
7878
tools: Optional[List[Tool]] = None,

ms_agent/tools/mcp_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ async def call_tool(self, server_name: str, tool_name: str,
6161
texts = []
6262
if response.isError:
6363
sep = '\n\n'
64-
return f'execute error: {sep.join(response.content)}'
64+
if all(isinstance(item, str) for item in response.content):
65+
return f'execute error: {sep.join(response.content)}'
66+
else:
67+
item_list = []
68+
for item in response.content:
69+
item_list.append(item.text)
70+
return f'execute error: {sep.join(item_list)}'
6571
for content in response.content:
6672
if content.type == 'text':
6773
texts.append(content.text)

ms_agent/utils/llm_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import asyncio
23
import functools
34
import time
45
from typing import Callable, Tuple, Type, TypeVar, Union
@@ -63,15 +64,17 @@ async def wrapper(*args, **kwargs) -> T:
6364

6465
for attempt in range(1, max_attempts + 1):
6566
try:
66-
return await func(*args, **kwargs)
67+
async for item in func(*args, **kwargs):
68+
yield item
69+
return
6770
except exceptions as e:
6871
last_exception = e
6972
if attempt < max_attempts:
7073
logger.warning(
7174
f'Attempt {attempt}/{max_attempts} fails: {func.__name__}. '
7275
f'Exception message: {e}. Will retry in {current_delay:.2f} seconds.'
7376
)
74-
time.sleep(current_delay)
77+
await asyncio.sleep(current_delay)
7578
current_delay *= backoff_factor
7679
else:
7780
logger.error(

0 commit comments

Comments (0)