Skip to content

Commit 78072d2

Browse files
authored
BREAKING CHANGE: Make many more dataclasses kw-only (#2738)
1 parent e0e3798 commit 78072d2

File tree

28 files changed

+174
-143
lines changed

28 files changed

+174
-143
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
"""
7979

8080

81-
@dataclasses.dataclass
81+
@dataclasses.dataclass(kw_only=True)
8282
class GraphAgentState:
8383
"""State kept across the execution of the agent graph."""
8484

@@ -99,7 +99,7 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None
9999
raise exceptions.UnexpectedModelBehavior(message)
100100

101101

102-
@dataclasses.dataclass
102+
@dataclasses.dataclass(kw_only=True)
103103
class GraphAgentDeps(Generic[DepsT, OutputDataT]):
104104
"""Dependencies/config passed to the agent graph."""
105105

@@ -157,6 +157,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
157157

158158
user_prompt: str | Sequence[_messages.UserContent] | None
159159

160+
_: dataclasses.KW_ONLY
161+
160162
instructions: str | None
161163
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
162164

@@ -359,8 +361,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
359361

360362
request: _messages.ModelRequest
361363

362-
_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
363-
_did_stream: bool = field(default=False, repr=False)
364+
_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
365+
_did_stream: bool = field(repr=False, init=False, default=False)
364366

365367
async def run(
366368
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -389,13 +391,13 @@ async def stream(
389391
self._did_stream = True
390392
ctx.state.usage.requests += 1
391393
agent_stream = result.AgentStream[DepsT, T](
392-
streamed_response,
393-
ctx.deps.output_schema,
394-
model_request_parameters,
395-
ctx.deps.output_validators,
396-
build_run_context(ctx),
397-
ctx.deps.usage_limits,
398-
ctx.deps.tool_manager,
394+
_raw_stream_response=streamed_response,
395+
_output_schema=ctx.deps.output_schema,
396+
_model_request_parameters=model_request_parameters,
397+
_output_validators=ctx.deps.output_validators,
398+
_run_ctx=build_run_context(ctx),
399+
_usage_limits=ctx.deps.usage_limits,
400+
_tool_manager=ctx.deps.tool_manager,
399401
)
400402
yield agent_stream
401403
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -475,9 +477,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
475477

476478
model_response: _messages.ModelResponse
477479

478-
_events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
480+
_events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
479481
_next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
480-
default=None, repr=False
482+
default=None, init=False, repr=False
481483
)
482484

483485
async def run(
@@ -629,7 +631,7 @@ async def _handle_text_response(
629631
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
630632
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
631633
else:
632-
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
634+
return self._handle_final_result(ctx, result.FinalResult(result_data), [])
633635

634636

635637
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
__all__ = ('function_schema',)
3131

3232

33-
@dataclass
33+
@dataclass(kw_only=True)
3434
class FunctionSchema:
3535
"""Internal information about a function schema."""
3636

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""Type variable for agent dependencies."""
1919

2020

21-
@dataclasses.dataclass(repr=False)
21+
@dataclasses.dataclass(repr=False, kw_only=True)
2222
class RunContext(Generic[AgentDepsT]):
2323
"""Information about the current call."""
2424

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,13 +1428,13 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
14281428
def __init__(
14291429
self,
14301430
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
1431-
max_retries: int = 1,
14321431
*,
1432+
max_retries: int = 1,
14331433
id: str | None = None,
14341434
output_schema: _output.BaseOutputSchema[Any],
14351435
):
14361436
self.output_schema = output_schema
1437-
super().__init__(tools, max_retries, id=id)
1437+
super().__init__(tools, max_retries=max_retries, id=id)
14381438

14391439
@property
14401440
def id(self) -> str:

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'WebSearchUserLocation', 'CodeExecutionTool', 'UrlContextTool')
1010

1111

12-
@dataclass
12+
@dataclass(kw_only=True)
1313
class AbstractBuiltinTool(ABC):
1414
"""A builtin tool that can be used by an agent.
1515
@@ -19,7 +19,7 @@ class AbstractBuiltinTool(ABC):
1919
"""
2020

2121

22-
@dataclass
22+
@dataclass(kw_only=True)
2323
class WebSearchTool(AbstractBuiltinTool):
2424
"""A builtin tool that allows your agent to search the web for information.
2525

pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from dataclasses import dataclass
2+
from dataclasses import KW_ONLY, dataclass
33

44
import anyio
55
import anyio.to_thread
@@ -43,7 +43,9 @@ class DuckDuckGoSearchTool:
4343
client: DDGS
4444
"""The DuckDuckGo search client."""
4545

46-
max_results: int | None = None
46+
_: KW_ONLY
47+
48+
max_results: int | None
4749
"""The maximum number of results. If None, returns results only from the first response."""
4850

4951
async def __call__(self, query: str) -> list[DuckDuckGoResult]:

0 commit comments

Comments
 (0)