77from typing import Any , AsyncIterator , Callable
88
99from coagent .core import Address , BaseAgent , Context , handler , logger
10+ from coagent .core .agent import is_async_iterator
1011from pydantic_core import PydanticUndefined
1112from pydantic .fields import FieldInfo
1213
@@ -44,7 +45,9 @@ def confirm(template: str):
4445
4546 def wrapper (func ):
4647 @functools .wraps (func )
47- async def run (* args : Any , ** kwargs : Any ) -> ChatMessage | str :
48+ async def run (
49+ * args : Any , ** kwargs : Any
50+ ) -> AsyncIterator [ChatMessage | str ] | ChatMessage | str :
4851 # Ask the user to confirm if not yet.
4952 ctx = kwargs .get ("ctx" , None )
5053 if ctx and not RunContext (ctx ).user_confirmed :
@@ -58,9 +61,11 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
5861 to_user = True ,
5962 )
6063
61- # Note that we assume that the tool is not an async generator,
62- # so we always use `await` here.
63- return await func (* args , ** kwargs )
64+ result = func (* args , ** kwargs )
65+ if is_async_iterator (result ):
66+ return result
67+ else :
68+ return await result
6469
6570 return run
6671
@@ -88,7 +93,9 @@ def submit(template: str = ""):
8893
8994 def wrapper (func ):
9095 @functools .wraps (func )
91- async def run (* args : Any , ** kwargs : Any ) -> ChatMessage | str :
96+ async def run (
97+ * args : Any , ** kwargs : Any
98+ ) -> AsyncIterator [ChatMessage | str ] | ChatMessage | str :
9299 # Ask the user to fill in the input form if not yet.
93100 ctx = kwargs .get ("ctx" , None )
94101 if ctx and not RunContext (ctx ).user_submitted :
@@ -107,9 +114,11 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
107114 to_user = True ,
108115 )
109116
110- # Note that we assume that the tool is not an async generator,
111- # so we always use `await` here.
112- return await func (* args , ** kwargs )
117+ result = func (* args , ** kwargs )
118+ if is_async_iterator (result ):
119+ return result
120+ else :
121+ return await result
113122
114123 return run
115124
@@ -148,7 +157,9 @@ def wrap_error(func):
148157 """Decorator to capture and return the possible error when running the given tool."""
149158
150159 @functools .wraps (func )
151- async def run (* args : Any , ** kwargs : Any ) -> ChatMessage | str :
160+ async def run (
161+ * args : Any , ** kwargs : Any
162+ ) -> AsyncIterator [ChatMessage | str ] | ChatMessage | str :
152163 try :
153164 # Fill in the missing arguments with default values if possible.
154165 #
@@ -163,9 +174,12 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
163174 else :
164175 kwargs [name ] = default
165176
166- # Note that we assume that the tool is not an async generator,
167- # so we always use `await` here.
168- return await func (* args , ** kwargs )
177+ result = func (* args , ** kwargs )
178+ if is_async_iterator (result ):
179+ return result
180+ else :
181+ return await result
182+
169183 except Exception as exc :
170184 logger .exception (exc )
171185 return f"Error: { exc } "
0 commit comments