Skip to content

Commit 17673a2

Browse files
committed
Improve wrap_error to fill in missing arguments with default values
1 parent 85e52f4 commit 17673a2

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

coagent/agents/chat_agent.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, AsyncIterator, Callable
77

88
from coagent.core import Address, BaseAgent, Context, handler, logger
9+
from pydantic_core import PydanticUndefined
910
from pydantic.fields import FieldInfo
1011

1112
from .aswarm import Agent as SwarmAgent, Swarm
@@ -39,14 +40,10 @@ async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
3940
if ctx and not RunContext(ctx).user_confirmed:
4041
# We assume that all meaningful arguments (includes `ctx` but
4142
# excepts possible `self`) are keyword arguments. Therefore,
42-
# here we use kwargs as the source of template variables.
43-
tmpl_vars = {
44-
k: v.default if isinstance(v, FieldInfo) else v
45-
for k, v in kwargs.items()
46-
}
43+
# here we use kwargs directly as the template variables.
4744
return ChatMessage(
4845
role="assistant",
49-
content=template.format(**tmpl_vars),
46+
content=template.format(**kwargs),
5047
type="confirm",
5148
to_user=True,
5249
)
@@ -94,6 +91,19 @@ def wrap_error(func):
9491
@functools.wraps(func)
9592
async def run(*args: Any, **kwargs: Any) -> ChatMessage | str:
9693
try:
94+
# Fill in the missing arguments with default values if possible.
95+
#
96+
# Note that we assume that all meaningful arguments (includes `ctx`
97+
# but excepts possible `self`) are keyword arguments.
98+
sig = inspect.signature(func)
99+
for name, param in sig.parameters.items():
100+
if name not in kwargs and isinstance(param.default, FieldInfo):
101+
default = param.default.default
102+
if default is PydanticUndefined:
103+
raise ValueError(f"Missing required argument {name!r}")
104+
else:
105+
kwargs[name] = default
106+
97107
# Note that we assume that the tool is not an async generator,
98108
# so we always use `await` here.
99109
return await func(*args, **kwargs)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ linting = [
7878
]
7979
testing = [
8080
"pytest",
81+
"pytest-asyncio",
8182
"pytest-cov",
8283
]
8384

tests/agents/test_chat_agent.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pydantic import Field
2+
import pytest
3+
4+
from coagent.agents.chat_agent import wrap_error
5+
6+
7+
@pytest.mark.asyncio
8+
async def test_wrap_error():
9+
@wrap_error
10+
async def func(
11+
a: int = Field(..., description="Argument a"),
12+
b: int = Field(1, description="Argument b"),
13+
) -> float:
14+
return a / b
15+
16+
assert await func() == "Error: Missing required argument 'a'"
17+
assert await func(a=1) == 1
18+
assert await func(a=1, b=0) == "Error: division by zero"

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)