Skip to content

Commit 0faad35

Browse files
committed
Make many more dataclasses kw-only
1 parent e0e3798 commit 0faad35

File tree

21 files changed

+86
-80
lines changed

21 files changed

+86
-80
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
"""
7979

8080

81-
@dataclasses.dataclass
81+
@dataclasses.dataclass(kw_only=True)
8282
class GraphAgentState:
8383
"""State kept across the execution of the agent graph."""
8484

@@ -99,7 +99,7 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None
9999
raise exceptions.UnexpectedModelBehavior(message)
100100

101101

102-
@dataclasses.dataclass
102+
@dataclasses.dataclass(kw_only=True)
103103
class GraphAgentDeps(Generic[DepsT, OutputDataT]):
104104
"""Dependencies/config passed to the agent graph."""
105105

@@ -151,7 +151,7 @@ def is_agent_node(
151151
return isinstance(node, AgentNode)
152152

153153

154-
@dataclasses.dataclass
154+
@dataclasses.dataclass(kw_only=True)
155155
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
156156
"""The node that handles the user prompt and instructions."""
157157

@@ -308,18 +308,22 @@ async def _reevaluate_dynamic_prompts(
308308
):
309309
updated_part_content = await runner.run(run_context)
310310
msg.parts[i] = _messages.SystemPromptPart(
311-
updated_part_content, dynamic_ref=part.dynamic_ref
311+
content=updated_part_content, dynamic_ref=part.dynamic_ref
312312
)
313313

314314
async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
315315
"""Build the initial messages for the conversation."""
316-
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
316+
messages: list[_messages.ModelRequestPart] = [
317+
_messages.SystemPromptPart(content=p) for p in self.system_prompts
318+
]
317319
for sys_prompt_runner in self.system_prompt_functions:
318320
prompt = await sys_prompt_runner.run(run_context)
319321
if sys_prompt_runner.dynamic:
320-
messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
322+
messages.append(
323+
_messages.SystemPromptPart(content=prompt, dynamic_ref=sys_prompt_runner.function.__qualname__)
324+
)
321325
else:
322-
messages.append(_messages.SystemPromptPart(prompt))
326+
messages.append(_messages.SystemPromptPart(content=prompt))
323327
return messages
324328

325329

@@ -359,8 +363,8 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
359363

360364
request: _messages.ModelRequest
361365

362-
_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False)
363-
_did_stream: bool = field(default=False, repr=False)
366+
_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
367+
_did_stream: bool = field(repr=False, init=False, default=False)
364368

365369
async def run(
366370
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -389,13 +393,13 @@ async def stream(
389393
self._did_stream = True
390394
ctx.state.usage.requests += 1
391395
agent_stream = result.AgentStream[DepsT, T](
392-
streamed_response,
393-
ctx.deps.output_schema,
394-
model_request_parameters,
395-
ctx.deps.output_validators,
396-
build_run_context(ctx),
397-
ctx.deps.usage_limits,
398-
ctx.deps.tool_manager,
396+
_raw_stream_response=streamed_response,
397+
_output_schema=ctx.deps.output_schema,
398+
_model_request_parameters=model_request_parameters,
399+
_output_validators=ctx.deps.output_validators,
400+
_run_ctx=build_run_context(ctx),
401+
_usage_limits=ctx.deps.usage_limits,
402+
_tool_manager=ctx.deps.tool_manager,
399403
)
400404
yield agent_stream
401405
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
@@ -475,9 +479,9 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]):
475479

476480
model_response: _messages.ModelResponse
477481

478-
_events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
482+
_events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, init=False, repr=False)
479483
_next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field(
480-
default=None, repr=False
484+
default=None, init=False, repr=False
481485
)
482486

483487
async def run(

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
__all__ = ('function_schema',)
3131

3232

33-
@dataclass
33+
@dataclass(kw_only=True)
3434
class FunctionSchema:
3535
"""Internal information about a function schema."""
3636

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""Type variable for agent dependencies."""
1919

2020

21-
@dataclasses.dataclass(repr=False)
21+
@dataclasses.dataclass(repr=False, kw_only=True)
2222
class RunContext(Generic[AgentDepsT]):
2323
"""Information about the current call."""
2424

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,13 +1428,13 @@ class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
14281428
def __init__(
14291429
self,
14301430
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
1431-
max_retries: int = 1,
14321431
*,
1432+
max_retries: int = 1,
14331433
id: str | None = None,
14341434
output_schema: _output.BaseOutputSchema[Any],
14351435
):
14361436
self.output_schema = output_schema
1437-
super().__init__(tools, max_retries, id=id)
1437+
super().__init__(tools, max_retries=max_retries, id=id)
14381438

14391439
@property
14401440
def id(self) -> str:

pydantic_ai_slim/pydantic_ai/builtin_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'WebSearchUserLocation', 'CodeExecutionTool', 'UrlContextTool')
1010

1111

12-
@dataclass
12+
@dataclass(kw_only=True)
1313
class AbstractBuiltinTool(ABC):
1414
"""A builtin tool that can be used by an agent.
1515
@@ -19,7 +19,7 @@ class AbstractBuiltinTool(ABC):
1919
"""
2020

2121

22-
@dataclass
22+
@dataclass(kw_only=True)
2323
class WebSearchTool(AbstractBuiltinTool):
2424
"""A builtin tool that allows your agent to search the web for information.
2525

pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,20 @@ class DuckDuckGoResult(TypedDict):
3636
duckduckgo_ta = TypeAdapter(list[DuckDuckGoResult])
3737

3838

39-
@dataclass
39+
@dataclass(init=False)
4040
class DuckDuckGoSearchTool:
4141
"""The DuckDuckGo search tool."""
4242

4343
client: DDGS
4444
"""The DuckDuckGo search client."""
4545

46-
max_results: int | None = None
46+
max_results: int | None
4747
"""The maximum number of results. If None, returns results only from the first response."""
4848

49+
def __init__(self, client: DDGS, *, max_results: int | None = None):
50+
self.client = client
51+
self.max_results = max_results
52+
4953
async def __call__(self, query: str) -> list[DuckDuckGoResult]:
5054
"""Searches DuckDuckGo for the given query and returns the results.
5155

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class FunctionModel(Model):
4444
Apart from `__init__`, all methods are private or match those of the base class.
4545
"""
4646

47-
function: FunctionDef | None = None
48-
stream_function: StreamFunctionDef | None = None
47+
function: FunctionDef | None
48+
stream_function: StreamFunctionDef | None
4949

5050
_model_name: str = field(repr=False)
5151
_system: str = field(default='function', repr=False)

pydantic_ai_slim/pydantic_ai/providers/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
) from _import_error
2929

3030

31-
@dataclass
31+
@dataclass(kw_only=True)
3232
class BedrockModelProfile(ModelProfile):
3333
"""Profile for models used with BedrockModel.
3434

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"""An invariant TypeVar."""
4343

4444

45-
@dataclass
45+
@dataclass(kw_only=True)
4646
class AgentStream(Generic[AgentDepsT, OutputDataT]):
4747
_raw_stream_response: models.StreamedResponse
4848
_output_schema: OutputSchema[OutputDataT]

pydantic_ai_slim/pydantic_ai/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _traceparent(self, *, required: bool = True) -> str | None:
101101
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
102102
"""The current context of the agent run."""
103103
return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]](
104-
self._graph_run.state, self._graph_run.deps
104+
state=self._graph_run.state, deps=self._graph_run.deps
105105
)
106106

107107
@property

0 commit comments

Comments
 (0)