Skip to content

Commit 9ae0309

Browse files
add a default to ResultData, some related cleanup (#512)
Co-authored-by: Sydney Runkle <[email protected]>
1 parent 866a031 commit 9ae0309

File tree

6 files changed

+42
-29
lines changed

6 files changed

+42
-29
lines changed

docs/api/result.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,3 @@
33
::: pydantic_ai.result
44
options:
55
inherited_members: true
6-
members:
7-
- ResultData
8-
- RunResult
9-
- StreamedRunResult
10-
- Usage

pydantic_ai_examples/sql_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class InvalidRequest(BaseModel):
7373

7474

7575
Response: TypeAlias = Union[Success, InvalidRequest]
76-
agent = Agent(
76+
agent: Agent[Deps, Response] = Agent(
7777
'gemini-1.5-flash',
7878
# Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
7979
result_type=Response, # type: ignore

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
from . import _utils, messages as _messages
1414
from .exceptions import ModelRetry
15-
from .result import ResultData
16-
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
15+
from .result import ResultData, ResultValidatorFunc
16+
from .tools import AgentDeps, RunContext, ToolDefinition
1717

1818

1919
@dataclass

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,56 @@
44
from collections.abc import AsyncIterator, Awaitable, Callable
55
from dataclasses import dataclass, field
66
from datetime import datetime
7-
from typing import Generic, TypeVar, cast
7+
from typing import Generic, Union, cast
88

99
import logfire_api
10+
from typing_extensions import TypeVar
1011

1112
from . import _result, _utils, exceptions, messages as _messages, models
1213
from .settings import UsageLimits
1314
from .tools import AgentDeps, RunContext
1415

1516
__all__ = (
1617
'ResultData',
18+
'ResultValidatorFunc',
1719
'Usage',
1820
'RunResult',
1921
'StreamedRunResult',
2022
)
2123

2224

23-
ResultData = TypeVar('ResultData')
25+
ResultData = TypeVar('ResultData', default=str)
2426
"""Type variable for the result data of a run."""
2527

28+
ResultValidatorFunc = Union[
29+
Callable[[RunContext[AgentDeps], ResultData], ResultData],
30+
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
31+
Callable[[ResultData], ResultData],
32+
Callable[[ResultData], Awaitable[ResultData]],
33+
]
34+
"""
35+
A function that always takes `ResultData` and returns `ResultData` and:
36+
37+
* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument
38+
* may or may not be async
39+
40+
Usage `ResultValidatorFunc[AgentDeps, ResultData]`.
41+
"""
42+
2643
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
2744

2845

2946
@dataclass
3047
class Usage:
31-
"""LLM usage associated to a request or run.
48+
"""LLM usage associated with a request or run.
3249
3350
Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests.
3451
3552
You'll need to look up the documentation of the model you're using to convert usage to monetary costs.
3653
"""
3754

3855
requests: int = 0
39-
"""Number of requests made."""
56+
"""Number of requests made to the LLM API."""
4057
request_tokens: int | None = None
4158
"""Tokens used in processing requests."""
4259
response_tokens: int | None = None

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
__all__ = (
1717
'AgentDeps',
1818
'RunContext',
19-
'ResultValidatorFunc',
2019
'SystemPromptFunc',
2120
'ToolFuncContext',
2221
'ToolFuncPlain',
@@ -73,21 +72,6 @@ def replace_with(
7372
Usage `SystemPromptFunc[AgentDeps]`.
7473
"""
7574

76-
ResultData = TypeVar('ResultData')
77-
78-
ResultValidatorFunc = Union[
79-
Callable[[RunContext[AgentDeps], ResultData], ResultData],
80-
Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
81-
Callable[[ResultData], ResultData],
82-
Callable[[ResultData], Awaitable[ResultData]],
83-
]
84-
"""
85-
A function that always takes `ResultData` and returns `ResultData`,
86-
but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
87-
88-
Usage `ResultValidator[AgentDeps, ResultData]`.
89-
"""
90-
9175
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
9276
"""A tool function that takes `RunContext` as the first argument.
9377

tests/typed_agent.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Awaitable, Iterator
44
from contextlib import contextmanager
55
from dataclasses import dataclass
6-
from typing import Callable, Union, assert_type
6+
from typing import Callable, TypeAlias, Union, assert_type
77

88
from pydantic_ai import Agent, ModelRetry, RunContext, Tool
99
from pydantic_ai.result import RunResult
@@ -178,6 +178,13 @@ def run_sync3() -> None:
178178
assert_type(result.data, Union[Foo, Bar])
179179

180180

181+
MyUnion: TypeAlias = 'Foo | Bar'
182+
union_agent2: Agent[None, MyUnion] = Agent(
183+
result_type=MyUnion, # type: ignore[arg-type]
184+
)
185+
assert_type(union_agent2, Agent[None, MyUnion])
186+
187+
181188
def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str:
182189
return f'{x} {y}'
183190

@@ -225,3 +232,13 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD
225232

226233
result = greet_agent.run_sync('testing...', deps='human')
227234
assert result.data == '{"greet":"hello a"}'
235+
236+
MYPY = False
237+
if not MYPY:
238+
default_agent = Agent()
239+
assert_type(default_agent, Agent[None, str])
240+
assert_type(default_agent, Agent[None])
241+
242+
partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps)
243+
assert_type(partial_agent, Agent[MyDeps, str])
244+
assert_type(partial_agent, Agent[MyDeps])

0 commit comments

Comments
 (0)