File tree Expand file tree Collapse file tree 2 files changed +33
-2
lines changed
Expand file tree Collapse file tree 2 files changed +33
-2
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,16 @@ def tool(func):
156156def wrap_error (func ):
157157 """Decorator to capture and return the possible error when running the given tool."""
158158
159+ async def __wrap_aiter (
160+ aiter_ : AsyncIterator [ChatMessage | str ],
161+ ) -> AsyncIterator [ChatMessage | str ]:
162+ try :
163+ async for chunk in aiter_ :
164+ yield chunk
165+ except Exception as exc :
166+ logger .exception (exc )
167+ yield f"Error: { exc } "
168+
159169 @functools .wraps (func )
160170 async def run (
161171 * args : Any , ** kwargs : Any
@@ -180,7 +190,7 @@ async def run(
180190
181191 result = func (* args , ** kwargs )
182192 if is_async_iterator (result ):
183- return result
193+ return __wrap_aiter ( result )
184194 else :
185195 return await result
186196
Original file line number Diff line number Diff line change 1+ from typing import AsyncIterator
2+
13from pydantic import Field
24import pytest
35
46from coagent .agents .chat_agent import wrap_error
57
68
79@pytest .mark .asyncio
8- async def test_wrap_error ():
10+ async def test_wrap_error_normal ():
911 @wrap_error
1012 async def func (
1113 a : int = Field (..., description = "Argument a" ),
@@ -16,3 +18,22 @@ async def func(
1618 assert await func () == "Error: Missing required argument 'a'"
1719 assert await func (a = 1 ) == 1
1820 assert await func (a = 1 , b = 0 ) == "Error: division by zero"
21+
22+
23+ @pytest .mark .asyncio
24+ async def test_wrap_error_aiter ():
25+ @wrap_error
26+ async def func (
27+ a : int = Field (..., description = "Argument a" ),
28+ b : int = Field (1 , description = "Argument b" ),
29+ ) -> AsyncIterator [float ]:
30+ yield a / b
31+
32+ result = await func ()
33+ assert result == "Error: Missing required argument 'a'"
34+
35+ result = await func (a = 1 )
36+ assert await anext (result ) == 1
37+
38+ result = await func (a = 1 , b = 0 )
39+ assert await anext (result ) == "Error: division by zero"
You can’t perform that action at this time.
0 commit comments