Skip to content

[Draft] Add input model to Agent #2393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT,


def agent_to_a2a(
agent: Agent[AgentDepsT, OutputDataT],
agent: Agent[AgentDepsT, OutputDataT, None],
*,
storage: Storage | None = None,
broker: Broker | None = None,
Expand Down
28 changes: 17 additions & 11 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
from pydantic_ai._run_context import InputDataT
from pydantic_ai._tool_manager import ToolManager
from pydantic_ai._utils import is_async_callable, run_in_executor
from pydantic_graph import BaseNode, Graph, GraphRunContext
Expand Down Expand Up @@ -92,10 +93,11 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None


@dataclasses.dataclass
class GraphAgentDeps(Generic[DepsT, OutputDataT]):
class GraphAgentDeps(Generic[DepsT, OutputDataT, InputDataT]):
"""Dependencies/config passed to the agent graph."""

user_deps: DepsT
user_input: InputDataT

prompt: str | Sequence[_messages.UserContent] | None
new_message_index: int
Expand All @@ -105,7 +107,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
usage_limits: _usage.UsageLimits
max_result_retries: int
end_strategy: EndStrategy
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
get_instructions: Callable[[RunContext[DepsT, InputDataT]], Awaitable[str | None]]

output_schema: _output.OutputSchema[OutputDataT]
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
Expand Down Expand Up @@ -148,11 +150,11 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
user_prompt: str | Sequence[_messages.UserContent] | None

instructions: str | None
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT, Any]]

system_prompts: tuple[str, ...]
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT, Any]]
system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT, Any]]

async def run(
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
Expand All @@ -175,7 +177,7 @@ async def _prepare_messages(
self,
user_prompt: str | Sequence[_messages.UserContent] | None,
message_history: list[_messages.ModelMessage] | None,
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
get_instructions: Callable[[RunContext[DepsT, Any]], Awaitable[str | None]],
run_context: RunContext[DepsT],
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
try:
Expand Down Expand Up @@ -214,7 +216,7 @@ async def _prepare_messages(
return messages, _messages.ModelRequest(parts, instructions=instructions)

async def _reevaluate_dynamic_prompts(
self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT, Any]
) -> None:
"""Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
# Only proceed if there's at least one dynamic runner.
Expand Down Expand Up @@ -554,9 +556,9 @@ async def _handle_text_response(
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])


def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT, Any]:
"""Build a `RunContext` object from the current agent graph run context."""
return RunContext[DepsT](
return RunContext[DepsT, Any](
deps=ctx.deps.user_deps,
model=ctx.deps.model,
usage=ctx.state.usage,
Expand All @@ -566,6 +568,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
trace_include_content=ctx.deps.instrumentation_settings is not None
and ctx.deps.instrumentation_settings.include_content,
run_step=ctx.state.run_step,
input=ctx.deps.user_input,
)


Expand Down Expand Up @@ -843,14 +846,17 @@ def build_agent_graph(
name: str | None,
deps_type: type[DepsT],
output_type: OutputSpec[OutputT],
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
input_type: type[InputDataT],
) -> Graph[
GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT], InputDataT], result.FinalResult[OutputT]
]:
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
nodes = (
UserPromptNode[DepsT],
ModelRequestNode[DepsT],
CallToolsNode[DepsT],
)
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any, Any], result.FinalResult[OutputT]](
nodes=nodes,
name=name or 'Agent',
state_type=GraphAgentState,
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def cli( # noqa: C901

async def run_chat(
stream: bool,
agent: Agent[AgentDepsT, OutputDataT],
agent: Agent[AgentDepsT, OutputDataT, Any],
console: Console,
code_theme: str,
prog_name: str,
Expand Down
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
"""Type variable for agent dependencies."""

InputDataT = TypeVar('InputDataT', default=None)
"""Type variable for the input data of a run."""


@dataclasses.dataclass(repr=False)
class RunContext(Generic[AgentDepsT]):
class RunContext(Generic[AgentDepsT, InputDataT]):
"""Information about the current call."""

deps: AgentDepsT
"""Dependencies for the agent."""
input: InputDataT
"""The input data for the run."""
model: Model
"""The model used in this run."""
usage: Usage
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import Any, Callable, Generic, cast

from . import _utils
from ._run_context import AgentDepsT, RunContext
from ._run_context import AgentDepsT, InputDataT, RunContext
from .tools import SystemPromptFunc


@dataclass
class SystemPromptRunner(Generic[AgentDepsT]):
function: SystemPromptFunc[AgentDepsT]
class SystemPromptRunner(Generic[AgentDepsT, InputDataT]):
function: SystemPromptFunc[AgentDepsT, InputDataT]
dynamic: bool = False
_takes_ctx: bool = field(init=False)
_is_async: bool = field(init=False)
Expand All @@ -21,7 +21,7 @@ def __post_init__(self):
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
self._is_async = _utils.is_async_callable(self.function)

async def run(self, run_context: RunContext[AgentDepsT]) -> str:
async def run(self, run_context: RunContext[AgentDepsT, InputDataT]) -> str:
if self._takes_ctx:
args = (run_context,)
else:
Expand Down
6 changes: 4 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ class ToolManager(Generic[AgentDepsT]):
"""Names of tools that failed in this run step."""

@classmethod
async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
async def build(
cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT, Any]
) -> ToolManager[AgentDepsT]:
"""Build a new tool manager for a specific run step."""
return cls(
ctx=ctx,
toolset=toolset,
tools=await toolset.get_tools(ctx),
)

async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
async def for_run_step(self, ctx: RunContext[AgentDepsT, Any]) -> ToolManager[AgentDepsT]:
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
retries = {
failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):

def __init__(
self,
agent: Agent[AgentDepsT, OutputDataT],
agent: Agent[AgentDepsT, OutputDataT, Any],
*,
# Agent.iter parameters.
output_type: OutputSpec[Any] | None = None,
Expand Down
Loading
Loading