diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py
index 8a916b5cc..cd4b76e0d 100644
--- a/pydantic_ai_slim/pydantic_ai/_a2a.py
+++ b/pydantic_ai_slim/pydantic_ai/_a2a.py
@@ -70,7 +70,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT,
 
 
 def agent_to_a2a(
-    agent: Agent[AgentDepsT, OutputDataT],
+    agent: Agent[AgentDepsT, OutputDataT, None],
     *,
     storage: Storage | None = None,
     broker: Broker | None = None,
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 12e6e07fe..d5b1aeb95 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -14,6 +14,7 @@
 from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
+from pydantic_ai._run_context import InputDataT
 from pydantic_ai._tool_manager import ToolManager
 from pydantic_ai._utils import is_async_callable, run_in_executor
 from pydantic_graph import BaseNode, Graph, GraphRunContext
@@ -92,10 +93,11 @@ def increment_retries(self, max_result_retries: int, error: BaseException | None
 
 
 @dataclasses.dataclass
-class GraphAgentDeps(Generic[DepsT, OutputDataT]):
+class GraphAgentDeps(Generic[DepsT, OutputDataT, InputDataT]):
     """Dependencies/config passed to the agent graph."""
 
     user_deps: DepsT
+    user_input: InputDataT
 
     prompt: str | Sequence[_messages.UserContent] | None
     new_message_index: int
@@ -105,7 +107,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     usage_limits: _usage.UsageLimits
     max_result_retries: int
     end_strategy: EndStrategy
-    get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
+    get_instructions: Callable[[RunContext[DepsT, InputDataT]], Awaitable[str | None]]
 
     output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
@@ -148,11 +150,11 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent] | None
 
     instructions: str | None
-    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
+    instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT, Any]]
     system_prompts: tuple[str, ...]
-    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
-    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]]
+    system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT, Any]]
+    system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT, Any]]
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
@@ -175,7 +177,7 @@ async def _prepare_messages(
         self,
         user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
-        get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
+        get_instructions: Callable[[RunContext[DepsT, Any]], Awaitable[str | None]],
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
@@ -214,7 +216,7 @@ async def _prepare_messages(
         return messages, _messages.ModelRequest(parts, instructions=instructions)
 
     async def _reevaluate_dynamic_prompts(
-        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT, Any]
     ) -> None:
         """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
         # Only proceed if there's at least one dynamic runner.
@@ -554,9 +556,9 @@ async def _handle_text_response(
         return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
 
 
-def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
+def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT, Any]:
    """Build a `RunContext` object from the current agent graph run context."""
-    return RunContext[DepsT](
+    return RunContext[DepsT, Any](
         deps=ctx.deps.user_deps,
         model=ctx.deps.model,
         usage=ctx.state.usage,
@@ -566,6 +568,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
         trace_include_content=ctx.deps.instrumentation_settings is not None
         and ctx.deps.instrumentation_settings.include_content,
         run_step=ctx.state.run_step,
+        input=ctx.deps.user_input,
     )
 
 
@@ -843,14 +846,17 @@ def build_agent_graph(
     name: str | None,
     deps_type: type[DepsT],
     output_type: OutputSpec[OutputT],
-) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
+    input_type: type[InputDataT],
+) -> Graph[
+    GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT], InputDataT], result.FinalResult[OutputT]
+]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (
         UserPromptNode[DepsT],
         ModelRequestNode[DepsT],
         CallToolsNode[DepsT],
     )
-    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[OutputT]](
+    graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any, Any], result.FinalResult[OutputT]](
         nodes=nodes,
         name=name or 'Agent',
         state_type=GraphAgentState,
diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py
index ae4f6ff6f..d1b22a329 100644
--- a/pydantic_ai_slim/pydantic_ai/_cli.py
+++ b/pydantic_ai_slim/pydantic_ai/_cli.py
@@ -220,7 +220,7 @@ def cli(  # noqa: C901
 
 async def run_chat(
     stream: bool,
-    agent: Agent[AgentDepsT, OutputDataT],
+    agent: Agent[AgentDepsT, OutputDataT, Any],
     console: Console,
     code_theme: str,
     prog_name: str,
diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py
index afad0e60e..c3daf2716 100644
--- a/pydantic_ai_slim/pydantic_ai/_run_context.py
+++ b/pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -17,13 +17,18 @@
 AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True)
 """Type variable for agent dependencies."""
 
+InputDataT = TypeVar('InputDataT', default=None)
+"""Type variable for the input data of a run."""
+
 
 @dataclasses.dataclass(repr=False)
-class RunContext(Generic[AgentDepsT]):
+class RunContext(Generic[AgentDepsT, InputDataT]):
     """Information about the current call."""
 
     deps: AgentDepsT
     """Dependencies for the agent."""
+    input: InputDataT
+    """The input data for the run."""
     model: Model
     """The model used in this run."""
     usage: Usage
diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py
index 55bad733c..5481fa87b 100644
--- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py
+++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py
@@ -6,13 +6,13 @@
 from typing import Any, Callable, Generic, cast
 
 from . import _utils
-from ._run_context import AgentDepsT, RunContext
+from ._run_context import AgentDepsT, InputDataT, RunContext
 from .tools import SystemPromptFunc
 
 
 @dataclass
-class SystemPromptRunner(Generic[AgentDepsT]):
-    function: SystemPromptFunc[AgentDepsT]
+class SystemPromptRunner(Generic[AgentDepsT, InputDataT]):
+    function: SystemPromptFunc[AgentDepsT, InputDataT]
     dynamic: bool = False
     _takes_ctx: bool = field(init=False)
     _is_async: bool = field(init=False)
@@ -21,7 +21,7 @@ def __post_init__(self):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
         self._is_async = _utils.is_async_callable(self.function)
 
-    async def run(self, run_context: RunContext[AgentDepsT]) -> str:
+    async def run(self, run_context: RunContext[AgentDepsT, InputDataT]) -> str:
         if self._takes_ctx:
             args = (run_context,)
         else:
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
index 612be4176..8cc010920 100644
--- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -31,7 +31,9 @@ class ToolManager(Generic[AgentDepsT]):
     """Names of tools that failed in this run step."""
 
     @classmethod
-    async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+    async def build(
+        cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT, Any]
+    ) -> ToolManager[AgentDepsT]:
         """Build a new tool manager for a specific run step."""
         return cls(
             ctx=ctx,
@@ -39,7 +41,7 @@ async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[Agent
             tools=await toolset.get_tools(ctx),
         )
 
-    async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+    async def for_run_step(self, ctx: RunContext[AgentDepsT, Any]) -> ToolManager[AgentDepsT]:
         """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
         retries = {
             failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py
index 1ea6f16eb..8f33f99a2 100644
--- a/pydantic_ai_slim/pydantic_ai/ag_ui.py
+++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py
@@ -117,7 +117,7 @@ class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
     def __init__(
         self,
-        agent: Agent[AgentDepsT, OutputDataT],
+        agent: Agent[AgentDepsT, OutputDataT, Any],
         *,
         # Agent.iter parameters.
         output_type: OutputSpec[Any] | None = None,
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 271d0ebc7..71a3483b3 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -33,6 +33,7 @@
 from ._agent_graph import HistoryProcessor
 from ._output import OutputToolset
 from ._tool_manager import ToolManager
+from .format_prompt import format_as_xml
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .profiles import ModelProfile
@@ -96,10 +97,13 @@
 RunOutputDataT = TypeVar('RunOutputDataT')
 """Type variable for the result data of a run where `output_type` was customized on the run call."""
 
+InputDataT = TypeVar('InputDataT', default=None)
+"""Type variable for the input data of a run."""
+
 
 @final
 @dataclasses.dataclass(init=False)
-class Agent(Generic[AgentDepsT, OutputDataT]):
+class Agent(Generic[AgentDepsT, OutputDataT, InputDataT]):
     """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
 
     Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT]
@@ -151,17 +155,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     _instrument_default: ClassVar[InstrumentationSettings | bool] = False
 
     _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
+    _input_type: type[InputDataT] = dataclasses.field(repr=False)
+
     _deprecated_result_tool_name: str | None = dataclasses.field(repr=False)
     _deprecated_result_tool_description: str | None = dataclasses.field(repr=False)
     _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
     _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
     _instructions: str | None = dataclasses.field(repr=False)
-    _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
+    _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT, InputDataT]] = dataclasses.field(
+        repr=False
+    )
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
-    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
-    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
+    _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT, InputDataT]] = dataclasses.field(
         repr=False
     )
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT, InputDataT]] = (
+        dataclasses.field(repr=False)
+    )
     _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
     _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False)
     _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
@@ -184,6 +194,7 @@ def __init__(
         | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
         | None = None,
         system_prompt: str | Sequence[str] = (),
+        input_type: type[InputDataT] = NoneType,
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
         model_settings: ModelSettings | None = None,
@@ -213,6 +224,7 @@ def __init__(
         | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
         | None = None,
         system_prompt: str | Sequence[str] = (),
+        input_type: type[InputDataT] = NoneType,
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
         model_settings: ModelSettings | None = None,
@@ -242,6 +254,7 @@ def __init__(
         | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
         | None = None,
         system_prompt: str | Sequence[str] = (),
+        input_type: type[InputDataT] = NoneType,
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
         model_settings: ModelSettings | None = None,
@@ -270,6 +283,7 @@ def __init__(
         | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
         | None = None,
         system_prompt: str | Sequence[str] = (),
+        input_type: type[InputDataT] = NoneType,
         deps_type: type[AgentDepsT] = NoneType,
         name: str | None = None,
         model_settings: ModelSettings | None = None,
@@ -300,6 +314,7 @@ def __init__(
                 parameterize the agent, and therefore get the best out of static type checking.
                 If you're not using deps, but want type checking to pass, you can set `deps=None` to satisfy Pyright
                 or add a type hint `: Agent[None, <return type>]`.
+            input_type: The type of the input data, provided via the run context to dynamic instructions / system prompts.
             name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                 when the agent is first run.
             model_settings: Optional model request settings to use for this agent's runs, by default.
@@ -351,6 +366,8 @@ def __init__(
 
         self.instrument = instrument
 
+        self._input_type = input_type
+
         self._deps_type = deps_type
 
         self._deprecated_result_tool_name = _deprecated_kwargs.pop('result_tool_name', None)
@@ -405,7 +422,7 @@ def __init__(
             if isinstance(instruction, str):
                 self._instructions += instruction + '\n'
             else:
-                self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
+                self._instructions_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT, Any](instruction))
         self._instructions = self._instructions.strip() or None
 
         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
@@ -448,6 +465,7 @@ async def run(
         output_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -464,6 +482,7 @@ async def run(
         output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -481,6 +500,7 @@ async def run(
         result_type: type[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -496,6 +516,7 @@ async def run(
         output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -527,6 +548,7 @@ async def main():
                 output validators since output validators would expect an argument that matches the agent's output type.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            input: Optional input data for this run, made available to instructions and system prompt functions via the run context.
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
@@ -551,6 +573,7 @@ async def main():
         async with self.iter(
             user_prompt=user_prompt,
             output_type=output_type,
+            input=input,
             message_history=message_history,
             model=model,
             deps=deps,
@@ -573,6 +596,7 @@ def iter(
         output_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -590,6 +614,7 @@ def iter(
         output_type: OutputSpec[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -608,6 +633,7 @@ def iter(
         result_type: type[RunOutputDataT],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -624,6 +650,7 @@ async def iter(
         output_type: OutputSpec[RunOutputDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
+        input: InputDataT = None,
         deps: AgentDepsT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
@@ -699,6 +726,7 @@ async def main():
                 output validators since output validators would expect an argument that matches the agent's output type.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            input: Optional input data for this run, made available to instructions and system prompt functions via the run context.
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
@@ -741,9 +770,9 @@ async def main():
             output_toolset.output_validators = output_validators
 
         # Build the graph
-        graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
-            _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_)
-        )
+        graph: Graph[
+            _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any, Any], FinalResult[Any]
+        ] = _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_, self._input_type)
 
         # Build the initial state
         usage = usage or _usage.Usage()
@@ -761,8 +790,9 @@ async def main():
             instrumentation_settings = None
             tracer = NoOpTracer()
 
-        run_context = RunContext[AgentDepsT](
+        run_context = RunContext[AgentDepsT, InputDataT](
             deps=deps,
+            input=input,
             model=model_used,
             usage=usage,
             prompt=user_prompt,
@@ -791,12 +821,15 @@ async def main():
             },
         )
 
-        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+        async def get_instructions(run_context: RunContext[AgentDepsT, InputDataT]) -> str | None:
             parts = [
                 self._instructions,
                 *[await func.run(run_context) for func in self._instructions_functions],
             ]
 
+            if not self._instructions_functions and input is not None:
+                parts.append(format_as_xml(obj=input))
+
             model_profile = model_used.profile
             if isinstance(output_schema, _output.PromptedOutputSchema):
                 instructions = output_schema.instructions(model_profile.prompted_output_template)
@@ -807,7 +840,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
                 return None
             return '\n\n'.join(parts).strip()
 
-        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT, InputDataT](
             user_deps=deps,
             prompt=user_prompt,
             new_message_index=new_message_index,
@@ -823,6 +856,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
             tracer=tracer,
             get_instructions=get_instructions,
             instrumentation_settings=instrumentation_settings,
+            user_input=input,
         )
         start_node = _agent_graph.UserPromptNode[AgentDepsT](
             user_prompt=user_prompt,
@@ -887,6 +921,7 @@ def run_sync(
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
+        input: InputDataT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -903,6 +938,7 @@ def run_sync(
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
+        input: InputDataT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -920,6 +956,7 @@ def run_sync(
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
+        input: InputDataT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -935,6 +972,7 @@ def run_sync(
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
+        input: InputDataT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -965,6 +1003,7 @@ def run_sync(
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
+            input: Optional input data for this run, made available to instructions and system prompt functions via the run context.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
             usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
@@ -989,6 +1028,7 @@ def run_sync(
             self.run(
                 user_prompt,
                 output_type=output_type,
+                input=input,
                 message_history=message_history,
                 model=model,
                 deps=deps,
@@ -1057,6 +1097,7 @@ async def run_stream(  # noqa C901
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
         deps: AgentDepsT = None,
+        input: InputDataT = None,
         model_settings: ModelSettings | None = None,
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
@@ -1113,6 +1154,7 @@ async def main():
         async with self.iter(
             user_prompt,
             output_type=output_type,
+            input=input,
             message_history=message_history,
             model=model,
             deps=deps,
@@ -1230,13 +1272,13 @@ def override(
 
     @overload
     def instructions(
-        self, func: Callable[[RunContext[AgentDepsT]], str], /
-    ) -> Callable[[RunContext[AgentDepsT]], str]: ...
+        self, func: Callable[[RunContext[AgentDepsT, InputDataT]], str], /
+    ) -> Callable[[RunContext[AgentDepsT, InputDataT]], str]: ...
 
     @overload
     def instructions(
-        self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
-    ) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
+        self, func: Callable[[RunContext[AgentDepsT, InputDataT]], Awaitable[str]], /
+    ) -> Callable[[RunContext[AgentDepsT, InputDataT]], Awaitable[str]]: ...
 
     @overload
     def instructions(self, func: Callable[[], str], /) -> Callable[[], str]: ...
@@ -1247,15 +1289,21 @@ def instructions(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Aw
 
     @overload
     def instructions(
         self, /
-    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]: ...
+    ) -> Callable[
+        [_system_prompt.SystemPromptFunc[AgentDepsT, InputDataT]],
+        _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT],
+    ]: ...
 
     def instructions(
         self,
-        func: _system_prompt.SystemPromptFunc[AgentDepsT] | None = None,
+        func: _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT] | None = None,
         /,
     ) -> (
-        Callable[[_system_prompt.SystemPromptFunc[AgentDepsT]], _system_prompt.SystemPromptFunc[AgentDepsT]]
-        | _system_prompt.SystemPromptFunc[AgentDepsT]
+        Callable[
+            [_system_prompt.SystemPromptFunc[AgentDepsT, InputDataT]],
+            _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT],
+        ]
+        | _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT]
     ):
         """Decorator to register an instructions function.
@@ -1285,14 +1333,14 @@ async def async_instructions(ctx: RunContext[str]) -> str:
         if func is None:
 
             def decorator(
-                func_: _system_prompt.SystemPromptFunc[AgentDepsT],
-            ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
-                self._instructions_functions.append(_system_prompt.SystemPromptRunner(func_))
+                func_: _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT],
+            ) -> _system_prompt.SystemPromptFunc[AgentDepsT, InputDataT]:
+                self._instructions_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT, InputDataT](func_))
                 return func_
 
             return decorator
         else:
-            self._instructions_functions.append(_system_prompt.SystemPromptRunner(func))
+            self._instructions_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT, InputDataT](func))
             return func
 
     @overload
@@ -1362,7 +1410,7 @@ async def async_system_prompt(ctx: RunContext[str]) -> str:
             def decorator(
                 func_: _system_prompt.SystemPromptFunc[AgentDepsT],
             ) -> _system_prompt.SystemPromptFunc[AgentDepsT]:
-                runner = _system_prompt.SystemPromptRunner[AgentDepsT](func_, dynamic=dynamic)
+                runner = _system_prompt.SystemPromptRunner[AgentDepsT, Any](func_, dynamic=dynamic)
                 self._system_prompt_functions.append(runner)
                 if dynamic:  # pragma: lax no cover
                     self._system_prompt_dynamic_functions[func_.__qualname__] = runner
@@ -1371,7 +1419,9 @@ def decorator(
             return decorator
         else:
             assert not dynamic, "dynamic can't be True in this case"
-            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
+            self._system_prompt_functions.append(
+                _system_prompt.SystemPromptRunner[AgentDepsT, Any](func, dynamic=dynamic)
+            )
             return func
 
     @overload
@@ -1656,7 +1706,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -
 
         return instrument_model(model_, instrument)
 
-    def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
+    def _get_deps(self: Agent[T, OutputDataT, InputDataT], deps: T) -> T:
         """Get deps for a run.
 
         If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index 14b5d12f3..11c592bfa 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -9,7 +9,7 @@
 from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar
 
 from . import _function_schema, _utils
-from ._run_context import AgentDepsT, RunContext
+from ._run_context import AgentDepsT, InputDataT, RunContext
 
 __all__ = (
     'AgentDepsT',
@@ -32,8 +32,8 @@
 """Retrieval function param spec."""
 
 SystemPromptFunc = Union[
-    Callable[[RunContext[AgentDepsT]], str],
-    Callable[[RunContext[AgentDepsT]], Awaitable[str]],
+    Callable[[RunContext[AgentDepsT, InputDataT]], str],
+    Callable[[RunContext[AgentDepsT, InputDataT]], Awaitable[str]],
     Callable[[], str],
     Callable[[], Awaitable[str]],
 ]
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 0b37d1041..ffcf28ec4 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -1685,6 +1685,52 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
     )
 
 
+def test_run_with_input():
+    class InputModel(BaseModel):
+        name: str
+
+    model = TestModel()
+    agent = Agent(model=model, input_type=InputModel)
+
+    @agent.instructions
+    def instructions(run_context: RunContext[Any, InputModel]) -> str:
+        return f'Hello {run_context.input.name}'
+
+    result: AgentRunResult[str] = agent.run_sync(input=InputModel(name='VALUE FROM INPUT MODEL'))
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(parts=[], instructions='Hello VALUE FROM INPUT MODEL'),
+            ModelResponse(
+                parts=[TextPart(content='success (no tool calls)')],
+                usage=Usage(requests=1, request_tokens=50, response_tokens=4, total_tokens=54),
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+        ]
+    )
+
+
+def test_run_with_input_no_instructions():
+    class InputModel(BaseModel):
+        name: str
+
+    model = TestModel()
+    agent = Agent(model=model, input_type=InputModel)
+
+    result: AgentRunResult[str] = agent.run_sync(input=InputModel(name='VALUE FROM INPUT MODEL'))
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(parts=[], instructions='VALUE FROM INPUT MODEL'),
+            ModelResponse(
+                parts=[TextPart(content='success (no tool calls)')],
+                usage=Usage(requests=1, request_tokens=50, response_tokens=4, total_tokens=54),
+                model_name='test',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+        ]
+    )
+
+
 def test_run_with_history_new():
     m = TestModel()
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index 1021b3151..6146141c7 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -67,7 +67,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent:
 
 @pytest.fixture
 def run_context(model: Model) -> RunContext[int]:
-    return RunContext(deps=0, model=model, usage=Usage())
+    return RunContext(deps=0, model=model, usage=Usage(), input=None)
 
 
 async def test_stdio_server(run_context: RunContext[int]):
diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py
index f217b34f4..df128ebb4 100644
--- a/tests/test_toolsets.py
+++ b/tests/test_toolsets.py
@@ -27,9 +27,10 @@
 T = TypeVar('T')
 
 
-def build_run_context(deps: T) -> RunContext[T]:
+def build_run_context(deps: T) -> RunContext[T, None]:
     return RunContext(
         deps=deps,
+        input=None,
         model=TestModel(),
         usage=Usage(),
         prompt=None,
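
Taken together, these changes add a third type parameter to `Agent` and `RunContext` so a run can carry typed input data alongside its dependencies: `input_type` fixes the type at agent construction, `input=...` supplies a value per run, and instructions/system prompt functions read it back as `RunContext.input`. A minimal usage sketch, assuming this branch is installed; `UserProfile` and `greet` are illustrative names, and `TestModel` stands in for a real model:

from pydantic import BaseModel

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel


class UserProfile(BaseModel):
    name: str


# input_type binds the new third type parameter (InputDataT).
agent = Agent(model=TestModel(), input_type=UserProfile)


@agent.instructions
def greet(ctx: RunContext[None, UserProfile]) -> str:
    # ctx.input is typed as UserProfile for this agent.
    return f'Address the user as {ctx.input.name}.'


# The value passed as `input` is exposed on the run context for the
# duration of this run.
result = agent.run_sync('Draft a welcome message.', input=UserProfile(name='Alice'))
print(result.output)

Note the precedence visible in the `get_instructions` change: explicitly registered instruction functions win, and the automatic `format_as_xml` rendering of `input` is only appended when no instruction functions are registered.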