Skip to content

Commit 48e207b

Browse files
committed
Handle tools that return async iterators more properly
1 parent 62fb0df commit 48e207b

File tree

2 files changed

+42
-17
lines changed

2 files changed

+42
-17
lines changed

coagent/agents/aswarm/core.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,19 +180,30 @@ async def handle_tool_calls(
180180
function_result = func(**args)
181181

182182
if is_async_iterator(function_result):
183-
# NOTE(luopeng)
183+
# NOTE(luopeng):
184184
#
185-
# If the function returns an async iterator, we assume that
186-
# the function is actually a sub-agent, then we should return
187-
# the stream directly to the user.
185+
# The function is an async generator function. We assume that
186+
# it's better to return the stream directly to the user.
188187
#
189188
# Note that this only works if there's one tool call in the batch.
190189
async for chunk in function_result:
191190
yield normalize_function_result(chunk)
192191
return
193192

193+
function_result_after_await = await function_result
194+
if is_async_iterator(function_result_after_await):
195+
# NOTE(luopeng):
196+
#
197+
# The function returns an async iterator internally. We assume
198+
# that it's better to return the stream directly to the user.
199+
#
200+
# Note that this only works if there's one tool call in the batch.
201+
async for chunk in function_result_after_await:
202+
yield normalize_function_result(chunk)
203+
return
204+
194205
# Non-streaming results are handled here.
195-
raw_result = normalize_function_result(await function_result)
206+
raw_result = normalize_function_result(function_result_after_await)
196207
if raw_result.to_user:
197208
# Return the reply directly to the user.
198209
yield raw_result

coagent/agents/chat_agent.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, AsyncIterator, Callable
88

99
from coagent.core import Address, BaseAgent, Context, handler, logger
10+
from coagent.core.agent import is_async_iterator
1011
from pydantic_core import PydanticUndefined
1112
from pydantic.fields import FieldInfo
1213

@@ -44,7 +45,9 @@ def confirm(template: str):
4445

4546
def wrapper(func):
4647
@functools.wraps(func)
47-
async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
48+
async def run(
49+
*args: Any, **kwargs: Any
50+
) -> AsyncIterator[ChatMessage | str] | ChatMessage | str:
4851
# Ask the user to confirm if not yet.
4952
ctx = kwargs.get("ctx", None)
5053
if ctx and not RunContext(ctx).user_confirmed:
@@ -58,9 +61,11 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
5861
to_user=True,
5962
)
6063

61-
# Note that we assume that the tool is not an async generator,
62-
# so we always use `await` here.
63-
return await func(*args, **kwargs)
64+
result = func(*args, **kwargs)
65+
if is_async_iterator(result):
66+
return result
67+
else:
68+
return await result
6469

6570
return run
6671

@@ -88,7 +93,9 @@ def submit(template: str = ""):
8893

8994
def wrapper(func):
9095
@functools.wraps(func)
91-
async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
96+
async def run(
97+
*args: Any, **kwargs: Any
98+
) -> AsyncIterator[ChatMessage | str] | ChatMessage | str:
9299
# Ask the user to fill in the input form if not yet.
93100
ctx = kwargs.get("ctx", None)
94101
if ctx and not RunContext(ctx).user_submitted:
@@ -107,9 +114,11 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
107114
to_user=True,
108115
)
109116

110-
# Note that we assume that the tool is not an async generator,
111-
# so we always use `await` here.
112-
return await func(*args, **kwargs)
117+
result = func(*args, **kwargs)
118+
if is_async_iterator(result):
119+
return result
120+
else:
121+
return await result
113122

114123
return run
115124

@@ -148,7 +157,9 @@ def wrap_error(func):
148157
"""Decorator to capture and return the possible error when running the given tool."""
149158

150159
@functools.wraps(func)
151-
async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
160+
async def run(
161+
*args: Any, **kwargs: Any
162+
) -> AsyncIterator[ChatMessage | str] | ChatMessage | str:
152163
try:
153164
# Fill in the missing arguments with default values if possible.
154165
#
@@ -163,9 +174,12 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
163174
else:
164175
kwargs[name] = default
165176

166-
# Note that we assume that the tool is not an async generator,
167-
# so we always use `await` here.
168-
return await func(*args, **kwargs)
177+
result = func(*args, **kwargs)
178+
if is_async_iterator(result):
179+
return result
180+
else:
181+
return await result
182+
169183
except Exception as exc:
170184
logger.exception(exc)
171185
return f"Error: {exc}"

0 commit comments

Comments
 (0)