Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class Email:
@dataclass
class State:
user: User
write_agent_messages: list[ModelMessage] = field(default_factory=list)
write_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])


email_writer_agent = Agent(
Expand Down Expand Up @@ -669,8 +669,8 @@ ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
ask_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])


@dataclass
Expand Down Expand Up @@ -912,8 +912,8 @@ ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True)
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
ask_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class Step(BaseModel):
class Plan(BaseModel):
"""Represents a plan with multiple steps."""

steps: list[Step] = Field(default_factory=list, description='The steps in the plan')
steps: list[Step] = Field(
default_factory=list[Step], description='The steps in the plan'
)


class JSONPatchOp(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions examples/pydantic_ai_examples/ag_ui/api/shared_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ class Recipe(BaseModel):
description='The skill level required for the recipe',
)
special_preferences: list[SpecialPreferences] = Field(
default_factory=list,
default_factory=list[SpecialPreferences],
description='Any special preferences for the recipe',
)
cooking_time: CookingTime = Field(
default=CookingTime.FIVE_MIN, description='The cooking time of the recipe'
)
ingredients: list[Ingredient] = Field(
default_factory=list,
default_factory=list[Ingredient],
description='Ingredients for the recipe',
)
instructions: list[str] = Field(
default_factory=list, description='Instructions for the recipe'
default_factory=list[str], description='Instructions for the recipe'
)


Expand Down
5 changes: 3 additions & 2 deletions examples/pydantic_ai_examples/data_analyst.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import cast

import datasets
import duckdb
Expand All @@ -9,7 +10,7 @@

@dataclass
class AnalystAgentDeps:
output: dict[str, pd.DataFrame] = field(default_factory=dict)
output: dict[str, pd.DataFrame] = field(default_factory=dict[str, pd.DataFrame])

def store(self, value: pd.DataFrame) -> str:
"""Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
Expand Down Expand Up @@ -47,7 +48,7 @@ def load_dataset(
"""
# begin load data from hf
builder = datasets.load_dataset_builder(path) # pyright: ignore[reportUnknownMemberType]
splits: dict[str, datasets.SplitInfo] = builder.info.splits or {} # pyright: ignore[reportUnknownMemberType]
splits = cast(dict[str, datasets.SplitInfo], builder.info.splits or {}) # pyright: ignore[reportUnknownMemberType]
if split not in splits:
raise ModelRetry(
f'{split} is not valid for dataset {path}. Valid splits are {",".join(splits.keys())}'
Expand Down
6 changes: 4 additions & 2 deletions examples/pydantic_ai_examples/question_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
@dataclass
class QuestionState:
question: str | None = None
ask_agent_messages: list[ModelMessage] = field(default_factory=list)
evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
ask_agent_messages: list[ModelMessage] = field(default_factory=list[ModelMessage])
evaluate_agent_messages: list[ModelMessage] = field(
default_factory=list[ModelMessage]
)


@dataclass
Expand Down
21 changes: 16 additions & 5 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,16 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
deferred_tool_results: DeferredToolResults | None = None

instructions: str | None = None
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=list[_system_prompt.SystemPromptRunner[DepsT]]
)

system_prompts: tuple[str, ...] = dataclasses.field(default_factory=tuple)
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=list[_system_prompt.SystemPromptRunner[DepsT]]
)
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(
default_factory=dict
default_factory=dict[str, _system_prompt.SystemPromptRunner[DepsT]]
)

async def run( # noqa: C901
Expand Down Expand Up @@ -929,6 +933,14 @@ async def handle_call_or_result(

pending = tasks
while pending:
pending = cast(
list[
Task[
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
]
],
pending,
)
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
Expand Down Expand Up @@ -1125,8 +1137,7 @@ async def _process_message_history(
if takes_ctx:
messages = await processor(run_context, messages)
else:
async_processor = cast(_HistoryProcessorAsync, processor)
messages = await async_processor(messages)
messages = await processor(messages)
else:
if takes_ctx:
sync_processor_with_ctx = cast(_HistoryProcessorSyncWithCtx[DepsT], processor)
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class FunctionSchema:
takes_ctx: bool
is_async: bool
single_arg_name: str | None = None
positional_fields: list[str] = field(default_factory=list)
positional_fields: list[str] = field(default_factory=list[str])
var_positional_field: str | None = None

async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
Expand Down
13 changes: 8 additions & 5 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,11 @@ def build( # noqa: C901
if len(outputs) == 0 and allows_deferred_tools:
raise UserError('At least one output type must be provided other than `DeferredToolRequests`.')

if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
if len(outputs) > 1:
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover

output = cast(NativeOutput[OutputDataT], output)
return NativeOutputSchema(
processor=cls._build_processor(
_flatten_output_spec(output.outputs),
Expand All @@ -282,10 +283,11 @@ def build( # noqa: C901
),
allows_deferred_tools=allows_deferred_tools,
)
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): # pyright: ignore[reportUnknownVariableType,reportUnknownArgumentType]
if len(outputs) > 1:
raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover

output = cast(PromptedOutput[OutputDataT], output)
return PromptedOutputSchema(
processor=cls._build_processor(
_flatten_output_spec(output.outputs),
Expand All @@ -303,9 +305,9 @@ def build( # noqa: C901
if output is str:
text_outputs.append(cast(type[str], output))
elif isinstance(output, TextOutput):
text_outputs.append(output)
text_outputs.append(cast(TextOutput[OutputDataT], output))
elif isinstance(output, ToolOutput):
tool_outputs.append(output)
tool_outputs.append(cast(ToolOutput[OutputDataT], output))
elif isinstance(output, NativeOutput):
# We can never get here because this is checked for above.
raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover
Expand Down Expand Up @@ -936,6 +938,7 @@ def build(
description = None
strict = None
if isinstance(output, ToolOutput):
output = cast(ToolOutput[OutputDataT], output)
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
name = output.name
description = output.description
Expand Down Expand Up @@ -1033,7 +1036,7 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
outputs: Sequence[OutputSpec[T]]
if isinstance(output_spec, Sequence):
outputs = output_spec
outputs = cast(Sequence[OutputSpec[T]], output_spec)
else:
outputs = (output_spec,)

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class ModelResponsePartsManager:
Parts are generally added and/or updated by providing deltas, which are tracked by vendor-specific IDs.
"""

_parts: list[ManagedPart] = field(default_factory=list, init=False)
_parts: list[ManagedPart] = field(default_factory=list[ManagedPart], init=False)
"""A list of parts (text or tool calls) that make up the current state of the model's response."""
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict, init=False)
_vendor_id_to_part_index: dict[VendorId, int] = field(default_factory=dict[VendorId, int], init=False)
"""Maps a vendor's "part" ID (if provided) to the index in `_parts` where that part resides."""

def get_parts(self) -> list[ModelResponsePart]:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ class RunContext(Generic[AgentDepsT]):
"""LLM usage associated with the run."""
prompt: str | Sequence[_messages.UserContent] | None = None
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
messages: list[_messages.ModelMessage] = field(default_factory=list[_messages.ModelMessage])
"""Messages exchanged in the conversation so far."""
tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
trace_include_content: bool = False
"""Whether to include the content of the messages in the trace."""
instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION
"""Instrumentation settings version, if instrumentation is enabled."""
retries: dict[str, int] = field(default_factory=dict)
retries: dict[str, int] = field(default_factory=dict[str, int])
"""Number of retries for each tool so far."""
tool_call_id: str | None = None
"""The ID of the tool call."""
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class ToolManager(Generic[AgentDepsT]):
"""The agent run context for a specific run step."""
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
failed_tools: set[str] = field(default_factory=set[str])
"""Names of tools that failed in this run step."""

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from datetime import datetime, timezone
from functools import partial
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, cast, get_args, get_origin, overload

from anyio.to_thread import run_sync
from pydantic import BaseModel, TypeAdapter
Expand Down Expand Up @@ -184,7 +184,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:
if task is None:
# aiter.__anext__() returns an Awaitable[T], not a Coroutine which asyncio.create_task expects
# so far, this doesn't seem to be a problem
task = asyncio.create_task(aiterator.__anext__()) # pyright: ignore[reportArgumentType]
task = cast(asyncio.Task[T], asyncio.create_task(aiterator.__anext__())) # pyright: ignore[reportArgumentType]

# we use asyncio.wait to avoid cancelling the coroutine if it's not done
done, _ = await asyncio.wait((task,), timeout=wait_time)
Expand Down Expand Up @@ -366,7 +366,7 @@ def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func

return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))


def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
Expand All @@ -386,7 +386,7 @@ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, s

# Handle arrays
if 'items' in s and isinstance(s['items'], dict):
items: dict[str, Any] = s['items']
items = cast(dict[str, Any], s['items'])
_update_mapped_json_schema_refs(items, name_mapping)
if 'prefixItems' in s:
prefix_items: list[dict[str, Any]] = s['prefixItems']
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ class _RequestStreamContext:
message_id: str = ''
part_end: BaseEvent | None = None
thinking: bool = False
builtin_tool_call_ids: dict[str, str] = field(default_factory=dict)
builtin_tool_call_ids: dict[str, str] = field(default_factory=dict[str, str])

def new_message_id(self) -> str:
"""Generate a new message ID for the request stream.
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class StreamedResponseSync:

_async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
_queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field(
default_factory=queue.Queue, init=False
default_factory=queue.Queue[messages.ModelResponseStreamEvent | Exception | None], init=False
)
_thread: threading.Thread | None = field(default=None, init=False)
_stream_response: StreamedResponse | None = field(default=None, init=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import AsyncIterator, Callable, Sequence
from contextlib import AbstractAsyncContextManager
from dataclasses import replace
from typing import Any
from typing import Any, cast

from pydantic.errors import PydanticUserError
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
Expand Down Expand Up @@ -117,7 +117,7 @@ def init_worker_plugin(self, next: WorkerPlugin) -> None:
self.next_worker_plugin = next

def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
activities = cast(Sequence[Callable[..., Any]], config.get('activities', [])) # pyright: ignore[reportUnknownMemberType]
# Activities are checked for name conflicts by Temporal.
config['activities'] = [*activities, *self.agent.temporal_activities]
return self.next_worker_plugin.configure_worker(config)
Expand Down
8 changes: 5 additions & 3 deletions pydantic_ai_slim/pydantic_ai/format_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ class _ToXml:
include_field_info: Literal['once'] | bool
# a map of Pydantic and dataclasses Field paths to their metadata:
# a field unique string representation and its class
_fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(default_factory=dict)
_fields_info: dict[str, tuple[str, FieldInfo | ComputedFieldInfo]] = field(
default_factory=dict[str, tuple[str, FieldInfo | ComputedFieldInfo]]
)
# keep track of fields we have extracted attributes from
_included_fields: set[str] = field(default_factory=set)
_included_fields: set[str] = field(default_factory=set[str])
# keep track of class names for dataclasses and Pydantic models, that occur in lists
_element_names: dict[str, str] = field(default_factory=dict)
_element_names: dict[str, str] = field(default_factory=dict[str, str])
# flag for parsing dataclasses and Pydantic models once
_is_info_extracted: bool = False
_FIELD_ATTRIBUTES = ('title', 'description')
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pydantic
import pydantic_core
from genai_prices import calc_price, types as genai_types
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
from opentelemetry._events import Event
from typing_extensions import deprecated

from . import _otel_messages, _utils
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,12 +293,12 @@
class ModelRequestParameters:
"""Configuration for an agent's request to a model, specifically related to tools and output handling."""

function_tools: list[ToolDefinition] = field(default_factory=list)
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)
function_tools: list[ToolDefinition] = field(default_factory=list[ToolDefinition])
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list[AbstractBuiltinTool])

output_mode: OutputMode = 'text'
output_object: OutputObjectDefinition | None = None
output_tools: list[ToolDefinition] = field(default_factory=list)
output_tools: list[ToolDefinition] = field(default_factory=list[ToolDefinition])
allow_text_output: bool = True

@cached_property
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import AsyncIterator, Callable
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from opentelemetry.trace import get_current_span

Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(
self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]

if isinstance(fallback_on, tuple):
fallback_on = cast(tuple[type[Exception], ...], fallback_on)
self._fallback_on = _default_fallback_condition_factory(fallback_on)
else:
self._fallback_on = fallback_on
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from genai_prices.types import PriceCalculation
from opentelemetry._events import (
Event, # pyright: ignore[reportPrivateImportUsage]
EventLogger, # pyright: ignore[reportPrivateImportUsage]
EventLoggerProvider, # pyright: ignore[reportPrivateImportUsage]
get_event_logger_provider, # pyright: ignore[reportPrivateImportUsage]
Event,
EventLogger,
EventLoggerProvider,
get_event_logger_provider,
)
from opentelemetry.metrics import MeterProvider, get_meter_provider
from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
Expand Down
Loading