From d1517b08968c342e1eb2f018d84de8d754bacba0 Mon Sep 17 00:00:00 2001 From: shayan-devv Date: Mon, 10 Nov 2025 21:16:30 +0500 Subject: [PATCH 1/4] feat(agent): add structured output support for tools on limited models (e.g., Gemini) Enable using tools and structured outputs together on models (e.g., Gemini) that don't natively support both simultaneously. Introduce an opt-in parameter enable_structured_output_with_tools to the Agent class, which injects JSON formatting instructions into the system prompt for LitellmModel as a workaround. Changes: - Add enable_structured_output_with_tools parameter to Agent (default: False) - Implement prompt injection utilities in src/agents/util/_prompts.py - Update LitellmModel to inject JSON instructions when enabled - Extend model interfaces to accept enable_structured_output_with_tools - Add comprehensive unit tests (13 total) and one integration test - Add documentation in docs/models/structured_output_with_tools.md - Update docs/agents.md and docs/models/litellm.md with usage examples --- docs/agents.md | 591 +++++----- docs/models/litellm.md | 213 ++-- docs/models/structured_output_with_tools.md | 237 ++++ mkdocs.yml | 1 + src/agents/agent.py | 16 + src/agents/extensions/models/litellm_model.py | 28 +- src/agents/models/interface.py | 258 ++-- src/agents/models/openai_chatcompletions.py | 729 ++++++------ src/agents/models/openai_responses.py | 1034 +++++++++-------- src/agents/run.py | 2 + src/agents/util/_prompts.py | 117 ++ tests/fake_model.py | 688 +++++------ tests/test_agent_prompt.py | 199 ++-- tests/test_gemini_local.py | 169 +++ tests/test_streaming_tool_call_arguments.py | 748 ++++++------ tests/utils/test_prompts.py | 107 ++ tests/voice/test_workflow.py | 440 +++---- 17 files changed, 3167 insertions(+), 2410 deletions(-) create mode 100644 docs/models/structured_output_with_tools.md create mode 100644 src/agents/util/_prompts.py create mode 100644 tests/test_gemini_local.py create mode 100644 tests/utils/test_prompts.py diff --git a/docs/agents.md b/docs/agents.md index d401f53da..14b5df295 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -1,285 +1,306 @@ -# Agents - -Agents are the core building block in your apps. An agent is a large language model (LLM), configured with instructions and tools. - -## Basic configuration - -The most common properties of an agent you'll configure are: - -- `name`: A required string that identifies your agent. -- `instructions`: also known as a developer message or system prompt. -- `model`: which LLM to use, and optional `model_settings` to configure model tuning parameters like temperature, top_p, etc. -- `tools`: Tools that the agent can use to achieve its tasks. - -```python -from agents import Agent, ModelSettings, function_tool - -@function_tool -def get_weather(city: str) -> str: - """returns weather info for the specified city.""" - return f"The weather in {city} is sunny" - -agent = Agent( - name="Haiku agent", - instructions="Always respond in haiku form", - model="gpt-5-nano", - tools=[get_weather], -) -``` - -## Context - -Agents are generic on their `context` type. Context is a dependency-injection tool: it's an object you create and pass to `Runner.run()`, that is passed to every agent, tool, handoff etc, and it serves as a grab bag of dependencies and state for the agent run. You can provide any Python object as the context. - -```python -@dataclass -class UserContext: - name: str - uid: str - is_pro_user: bool - - async def fetch_purchases() -> list[Purchase]: - return ... - -agent = Agent[UserContext]( - ..., -) -``` - -## Output types - -By default, agents produce plain text (i.e. `str`) outputs. If you want the agent to produce a particular type of output, you can use the `output_type` parameter. A common choice is to use [Pydantic](https://docs.pydantic.dev/) objects, but we support any type that can be wrapped in a Pydantic [TypeAdapter](https://docs.pydantic.dev/latest/api/type_adapter/) - dataclasses, lists, TypedDict, etc. - -```python -from pydantic import BaseModel -from agents import Agent - - -class CalendarEvent(BaseModel): - name: str - date: str - participants: list[str] - -agent = Agent( - name="Calendar extractor", - instructions="Extract calendar events from text", - output_type=CalendarEvent, -) -``` - -!!! note - - When you pass an `output_type`, that tells the model to use [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) instead of regular plain text responses. - -## Multi-agent system design patterns - -There are many ways to design multi‑agent systems, but we commonly see two broadly applicable patterns: - -1. Manager (agents as tools): A central manager/orchestrator invokes specialized sub‑agents as tools and retains control of the conversation. -2. Handoffs: Peer agents hand off control to a specialized agent that takes over the conversation. This is decentralized. - -See [our practical guide to building agents](https://cdn.openai.com/business-guides-and-resources/a-practical-guide-to-building-agents.pdf) for more details. - -### Manager (agents as tools) - -The `customer_facing_agent` handles all user interaction and invokes specialized sub‑agents exposed as tools. Read more in the [tools](tools.md#agents-as-tools) documentation. - -```python -from agents import Agent - -booking_agent = Agent(...) -refund_agent = Agent(...) - -customer_facing_agent = Agent( - name="Customer-facing agent", - instructions=( - "Handle all direct user communication. " - "Call the relevant tools when specialized expertise is needed." - ), - tools=[ - booking_agent.as_tool( - tool_name="booking_expert", - tool_description="Handles booking questions and requests.", - ), - refund_agent.as_tool( - tool_name="refund_expert", - tool_description="Handles refund questions and requests.", - ) - ], -) -``` - -### Handoffs - -Handoffs are sub‑agents the agent can delegate to. When a handoff occurs, the delegated agent receives the conversation history and takes over the conversation. This pattern enables modular, specialized agents that excel at a single task. Read more in the [handoffs](handoffs.md) documentation. - -```python -from agents import Agent - -booking_agent = Agent(...) -refund_agent = Agent(...) - -triage_agent = Agent( - name="Triage agent", - instructions=( - "Help the user with their questions. " - "If they ask about booking, hand off to the booking agent. " - "If they ask about refunds, hand off to the refund agent." - ), - handoffs=[booking_agent, refund_agent], -) -``` - -## Dynamic instructions - -In most cases, you can provide instructions when you create the agent. However, you can also provide dynamic instructions via a function. The function will receive the agent and context, and must return the prompt. Both regular and `async` functions are accepted. - -```python -def dynamic_instructions( - context: RunContextWrapper[UserContext], agent: Agent[UserContext] -) -> str: - return f"The user's name is {context.context.name}. Help them with their questions." - - -agent = Agent[UserContext]( - name="Triage agent", - instructions=dynamic_instructions, -) -``` - -## Lifecycle events (hooks) - -Sometimes, you want to observe the lifecycle of an agent. For example, you may want to log events, or pre-fetch data when certain events occur. You can hook into the agent lifecycle with the `hooks` property. Subclass the [`AgentHooks`][agents.lifecycle.AgentHooks] class, and override the methods you're interested in. - -## Guardrails - -Guardrails allow you to run checks/validations on user input in parallel to the agent running, and on the agent's output once it is produced. For example, you could screen the user's input and agent's output for relevance. Read more in the [guardrails](guardrails.md) documentation. - -## Cloning/copying agents - -By using the `clone()` method on an agent, you can duplicate an Agent, and optionally change any properties you like. - -```python -pirate_agent = Agent( - name="Pirate", - instructions="Write like a pirate", - model="gpt-4.1", -) - -robot_agent = pirate_agent.clone( - name="Robot", - instructions="Write like a robot", -) -``` - -## Forcing tool use - -Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are: - -1. `auto`, which allows the LLM to decide whether or not to use a tool. -2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool). -3. `none`, which requires the LLM to _not_ use a tool. -4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool. - -```python -from agents import Agent, Runner, function_tool, ModelSettings - -@function_tool -def get_weather(city: str) -> str: - """Returns weather info for the specified city.""" - return f"The weather in {city} is sunny" - -agent = Agent( - name="Weather Agent", - instructions="Retrieve weather details.", - tools=[get_weather], - model_settings=ModelSettings(tool_choice="get_weather") -) -``` - -## Tool Use Behavior - -The `tool_use_behavior` parameter in the `Agent` configuration controls how tool outputs are handled: - -- `"run_llm_again"`: The default. Tools are run, and the LLM processes the results to produce a final response. -- `"stop_on_first_tool"`: The output of the first tool call is used as the final response, without further LLM processing. - -```python -from agents import Agent, Runner, function_tool, ModelSettings - -@function_tool -def get_weather(city: str) -> str: - """Returns weather info for the specified city.""" - return f"The weather in {city} is sunny" - -agent = Agent( - name="Weather Agent", - instructions="Retrieve weather details.", - tools=[get_weather], - tool_use_behavior="stop_on_first_tool" -) -``` - -- `StopAtTools(stop_at_tool_names=[...])`: Stops if any specified tool is called, using its output as the final response. - -```python -from agents import Agent, Runner, function_tool -from agents.agent import StopAtTools - -@function_tool -def get_weather(city: str) -> str: - """Returns weather info for the specified city.""" - return f"The weather in {city} is sunny" - -@function_tool -def sum_numbers(a: int, b: int) -> int: - """Adds two numbers.""" - return a + b - -agent = Agent( - name="Stop At Stock Agent", - instructions="Get weather or sum numbers.", - tools=[get_weather, sum_numbers], - tool_use_behavior=StopAtTools(stop_at_tool_names=["get_weather"]) -) -``` - -- `ToolsToFinalOutputFunction`: A custom function that processes tool results and decides whether to stop or continue with the LLM. - -```python -from agents import Agent, Runner, function_tool, FunctionToolResult, RunContextWrapper -from agents.agent import ToolsToFinalOutputResult -from typing import List, Any - -@function_tool -def get_weather(city: str) -> str: - """Returns weather info for the specified city.""" - return f"The weather in {city} is sunny" - -def custom_tool_handler( - context: RunContextWrapper[Any], - tool_results: List[FunctionToolResult] -) -> ToolsToFinalOutputResult: - """Processes tool results to decide final output.""" - for result in tool_results: - if result.output and "sunny" in result.output: - return ToolsToFinalOutputResult( - is_final_output=True, - final_output=f"Final weather: {result.output}" - ) - return ToolsToFinalOutputResult( - is_final_output=False, - final_output=None - ) - -agent = Agent( - name="Weather Agent", - instructions="Retrieve weather details.", - tools=[get_weather], - tool_use_behavior=custom_tool_handler -) -``` - -!!! note - - To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum. +# Agents + +Agents are the core building block in your apps. An agent is a large language model (LLM), configured with instructions and tools. + +## Basic configuration + +The most common properties of an agent you'll configure are: + +- `name`: A required string that identifies your agent. +- `instructions`: also known as a developer message or system prompt. +- `model`: which LLM to use, and optional `model_settings` to configure model tuning parameters like temperature, top_p, etc. +- `tools`: Tools that the agent can use to achieve its tasks. + +```python +from agents import Agent, ModelSettings, function_tool + +@function_tool +def get_weather(city: str) -> str: + """returns weather info for the specified city.""" + return f"The weather in {city} is sunny" + +agent = Agent( + name="Haiku agent", + instructions="Always respond in haiku form", + model="gpt-5-nano", + tools=[get_weather], +) +``` + +## Context + +Agents are generic on their `context` type. Context is a dependency-injection tool: it's an object you create and pass to `Runner.run()`, that is passed to every agent, tool, handoff etc, and it serves as a grab bag of dependencies and state for the agent run. You can provide any Python object as the context. + +```python +@dataclass +class UserContext: + name: str + uid: str + is_pro_user: bool + + async def fetch_purchases() -> list[Purchase]: + return ... + +agent = Agent[UserContext]( + ..., +) +``` + +## Output types + +By default, agents produce plain text (i.e. `str`) outputs. If you want the agent to produce a particular type of output, you can use the `output_type` parameter. A common choice is to use [Pydantic](https://docs.pydantic.dev/) objects, but we support any type that can be wrapped in a Pydantic [TypeAdapter](https://docs.pydantic.dev/latest/api/type_adapter/) - dataclasses, lists, TypedDict, etc. + +```python +from pydantic import BaseModel +from agents import Agent + + +class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + +agent = Agent( + name="Calendar extractor", + instructions="Extract calendar events from text", + output_type=CalendarEvent, +) +``` + +!!! note + + When you pass an `output_type`, that tells the model to use [structured outputs](https://platform.openai.com/docs/guides/structured-outputs) instead of regular plain text responses. + +### Using structured outputs with tools + +Some models (like Google Gemini) don't natively support using tools and structured outputs together. For these cases, you can enable prompt injection: + +```python +from agents import Agent +from agents.extensions.models.litellm_model import LitellmModel + +agent = Agent( + name="Weather assistant", + model=LitellmModel("gemini/gemini-1.5-flash"), + tools=[get_weather], + output_type=WeatherReport, + enable_structured_output_with_tools=True, # Required for Gemini +) +``` + +The `enable_structured_output_with_tools` parameter injects JSON formatting instructions into the system prompt as a workaround. This is only needed for models accessed via [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] that lack native support. OpenAI models ignore this parameter. + +See the [prompt injection documentation](models/structured_output_with_tools.md) for more details. + +## Multi-agent system design patterns + +There are many ways to design multi‑agent systems, but we commonly see two broadly applicable patterns: + +1. Manager (agents as tools): A central manager/orchestrator invokes specialized sub‑agents as tools and retains control of the conversation. +2. Handoffs: Peer agents hand off control to a specialized agent that takes over the conversation. This is decentralized. + +See [our practical guide to building agents](https://cdn.openai.com/business-guides-and-resources/a-practical-guide-to-building-agents.pdf) for more details. + +### Manager (agents as tools) + +The `customer_facing_agent` handles all user interaction and invokes specialized sub‑agents exposed as tools. Read more in the [tools](tools.md#agents-as-tools) documentation. + +```python +from agents import Agent + +booking_agent = Agent(...) +refund_agent = Agent(...) + +customer_facing_agent = Agent( + name="Customer-facing agent", + instructions=( + "Handle all direct user communication. " + "Call the relevant tools when specialized expertise is needed." + ), + tools=[ + booking_agent.as_tool( + tool_name="booking_expert", + tool_description="Handles booking questions and requests.", + ), + refund_agent.as_tool( + tool_name="refund_expert", + tool_description="Handles refund questions and requests.", + ) + ], +) +``` + +### Handoffs + +Handoffs are sub‑agents the agent can delegate to. When a handoff occurs, the delegated agent receives the conversation history and takes over the conversation. This pattern enables modular, specialized agents that excel at a single task. Read more in the [handoffs](handoffs.md) documentation. + +```python +from agents import Agent + +booking_agent = Agent(...) +refund_agent = Agent(...) + +triage_agent = Agent( + name="Triage agent", + instructions=( + "Help the user with their questions. " + "If they ask about booking, hand off to the booking agent. " + "If they ask about refunds, hand off to the refund agent." + ), + handoffs=[booking_agent, refund_agent], +) +``` + +## Dynamic instructions + +In most cases, you can provide instructions when you create the agent. However, you can also provide dynamic instructions via a function. The function will receive the agent and context, and must return the prompt. Both regular and `async` functions are accepted. + +```python +def dynamic_instructions( + context: RunContextWrapper[UserContext], agent: Agent[UserContext] +) -> str: + return f"The user's name is {context.context.name}. Help them with their questions." + + +agent = Agent[UserContext]( + name="Triage agent", + instructions=dynamic_instructions, +) +``` + +## Lifecycle events (hooks) + +Sometimes, you want to observe the lifecycle of an agent. For example, you may want to log events, or pre-fetch data when certain events occur. You can hook into the agent lifecycle with the `hooks` property. Subclass the [`AgentHooks`][agents.lifecycle.AgentHooks] class, and override the methods you're interested in. + +## Guardrails + +Guardrails allow you to run checks/validations on user input in parallel to the agent running, and on the agent's output once it is produced. For example, you could screen the user's input and agent's output for relevance. Read more in the [guardrails](guardrails.md) documentation. + +## Cloning/copying agents + +By using the `clone()` method on an agent, you can duplicate an Agent, and optionally change any properties you like. + +```python +pirate_agent = Agent( + name="Pirate", + instructions="Write like a pirate", + model="gpt-4.1", +) + +robot_agent = pirate_agent.clone( + name="Robot", + instructions="Write like a robot", +) +``` + +## Forcing tool use + +Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are: + +1. `auto`, which allows the LLM to decide whether or not to use a tool. +2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool). +3. `none`, which requires the LLM to _not_ use a tool. +4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool. + +```python +from agents import Agent, Runner, function_tool, ModelSettings + +@function_tool +def get_weather(city: str) -> str: + """Returns weather info for the specified city.""" + return f"The weather in {city} is sunny" + +agent = Agent( + name="Weather Agent", + instructions="Retrieve weather details.", + tools=[get_weather], + model_settings=ModelSettings(tool_choice="get_weather") +) +``` + +## Tool Use Behavior + +The `tool_use_behavior` parameter in the `Agent` configuration controls how tool outputs are handled: + +- `"run_llm_again"`: The default. Tools are run, and the LLM processes the results to produce a final response. +- `"stop_on_first_tool"`: The output of the first tool call is used as the final response, without further LLM processing. + +```python +from agents import Agent, Runner, function_tool, ModelSettings + +@function_tool +def get_weather(city: str) -> str: + """Returns weather info for the specified city.""" + return f"The weather in {city} is sunny" + +agent = Agent( + name="Weather Agent", + instructions="Retrieve weather details.", + tools=[get_weather], + tool_use_behavior="stop_on_first_tool" +) +``` + +- `StopAtTools(stop_at_tool_names=[...])`: Stops if any specified tool is called, using its output as the final response. + +```python +from agents import Agent, Runner, function_tool +from agents.agent import StopAtTools + +@function_tool +def get_weather(city: str) -> str: + """Returns weather info for the specified city.""" + return f"The weather in {city} is sunny" + +@function_tool +def sum_numbers(a: int, b: int) -> int: + """Adds two numbers.""" + return a + b + +agent = Agent( + name="Stop At Stock Agent", + instructions="Get weather or sum numbers.", + tools=[get_weather, sum_numbers], + tool_use_behavior=StopAtTools(stop_at_tool_names=["get_weather"]) +) +``` + +- `ToolsToFinalOutputFunction`: A custom function that processes tool results and decides whether to stop or continue with the LLM. + +```python +from agents import Agent, Runner, function_tool, FunctionToolResult, RunContextWrapper +from agents.agent import ToolsToFinalOutputResult +from typing import List, Any + +@function_tool +def get_weather(city: str) -> str: + """Returns weather info for the specified city.""" + return f"The weather in {city} is sunny" + +def custom_tool_handler( + context: RunContextWrapper[Any], + tool_results: List[FunctionToolResult] +) -> ToolsToFinalOutputResult: + """Processes tool results to decide final output.""" + for result in tool_results: + if result.output and "sunny" in result.output: + return ToolsToFinalOutputResult( + is_final_output=True, + final_output=f"Final weather: {result.output}" + ) + return ToolsToFinalOutputResult( + is_final_output=False, + final_output=None + ) + +agent = Agent( + name="Weather Agent", + instructions="Retrieve weather details.", + tools=[get_weather], + tool_use_behavior=custom_tool_handler +) +``` + +!!! note + + To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum. diff --git a/docs/models/litellm.md b/docs/models/litellm.md index 08263feef..163877925 100644 --- a/docs/models/litellm.md +++ b/docs/models/litellm.md @@ -1,90 +1,123 @@ -# Using any model via LiteLLM - -!!! note - - The LiteLLM integration is in beta. You may run into issues with some model providers, especially smaller ones. Please report any issues via [Github issues](https://github.com/openai/openai-agents-python/issues) and we'll fix quickly. - -[LiteLLM](https://docs.litellm.ai/docs/) is a library that allows you to use 100+ models via a single interface. We've added a LiteLLM integration to allow you to use any AI model in the Agents SDK. - -## Setup - -You'll need to ensure `litellm` is available. You can do this by installing the optional `litellm` dependency group: - -```bash -pip install "openai-agents[litellm]" -``` - -Once done, you can use [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] in any agent. - -## Example - -This is a fully working example. When you run it, you'll be prompted for a model name and API key. For example, you could enter: - -- `openai/gpt-4.1` for the model, and your OpenAI API key -- `anthropic/claude-3-5-sonnet-20240620` for the model, and your Anthropic API key -- etc - -For a full list of models supported in LiteLLM, see the [litellm providers docs](https://docs.litellm.ai/docs/providers). - -```python -from __future__ import annotations - -import asyncio - -from agents import Agent, Runner, function_tool, set_tracing_disabled -from agents.extensions.models.litellm_model import LitellmModel - -@function_tool -def get_weather(city: str): - print(f"[debug] getting weather for {city}") - return f"The weather in {city} is sunny." - - -async def main(model: str, api_key: str): - agent = Agent( - name="Assistant", - instructions="You only respond in haikus.", - model=LitellmModel(model=model, api_key=api_key), - tools=[get_weather], - ) - - result = await Runner.run(agent, "What's the weather in Tokyo?") - print(result.final_output) - - -if __name__ == "__main__": - # First try to get model/api key from args - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, required=False) - parser.add_argument("--api-key", type=str, required=False) - args = parser.parse_args() - - model = args.model - if not model: - model = input("Enter a model name for Litellm: ") - - api_key = args.api_key - if not api_key: - api_key = input("Enter an API key for Litellm: ") - - asyncio.run(main(model, api_key)) -``` - -## Tracking usage data - -If you want LiteLLM responses to populate the Agents SDK usage metrics, pass `ModelSettings(include_usage=True)` when creating your agent. - -```python -from agents import Agent, ModelSettings -from agents.extensions.models.litellm_model import LitellmModel - -agent = Agent( - name="Assistant", - model=LitellmModel(model="your/model", api_key="..."), - model_settings=ModelSettings(include_usage=True), -) -``` - -With `include_usage=True`, LiteLLM requests report token and request counts through `result.context_wrapper.usage` just like the built-in OpenAI models. +# Using any model via LiteLLM + +!!! note + + The LiteLLM integration is in beta. You may run into issues with some model providers, especially smaller ones. Please report any issues via [Github issues](https://github.com/openai/openai-agents-python/issues) and we'll fix quickly. + +[LiteLLM](https://docs.litellm.ai/docs/) is a library that allows you to use 100+ models via a single interface. We've added a LiteLLM integration to allow you to use any AI model in the Agents SDK. + +## Setup + +You'll need to ensure `litellm` is available. You can do this by installing the optional `litellm` dependency group: + +```bash +pip install "openai-agents[litellm]" +``` + +Once done, you can use [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] in any agent. + +## Example + +This is a fully working example. When you run it, you'll be prompted for a model name and API key. For example, you could enter: + +- `openai/gpt-4.1` for the model, and your OpenAI API key +- `anthropic/claude-3-5-sonnet-20240620` for the model, and your Anthropic API key +- etc + +For a full list of models supported in LiteLLM, see the [litellm providers docs](https://docs.litellm.ai/docs/providers). + +```python +from __future__ import annotations + +import asyncio + +from agents import Agent, Runner, function_tool, set_tracing_disabled +from agents.extensions.models.litellm_model import LitellmModel + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(model: str, api_key: str): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=LitellmModel(model=model, api_key=api_key), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + # First try to get model/api key from args + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=False) + parser.add_argument("--api-key", type=str, required=False) + args = parser.parse_args() + + model = args.model + if not model: + model = input("Enter a model name for Litellm: ") + + api_key = args.api_key + if not api_key: + api_key = input("Enter an API key for Litellm: ") + + asyncio.run(main(model, api_key)) +``` + +## Tracking usage data + +If you want LiteLLM responses to populate the Agents SDK usage metrics, pass `ModelSettings(include_usage=True)` when creating your agent. + +```python +from agents import Agent, ModelSettings +from agents.extensions.models.litellm_model import LitellmModel + +agent = Agent( + name="Assistant", + model=LitellmModel(model="your/model", api_key="..."), + model_settings=ModelSettings(include_usage=True), +) +``` + +With `include_usage=True`, LiteLLM requests report token and request counts through `result.context_wrapper.usage` just like the built-in OpenAI models. + +## Using tools with structured outputs + +Some models accessed via LiteLLM (particularly Google Gemini) don't natively support using tools and structured outputs simultaneously. For these models, enable prompt injection: + +```python +from pydantic import BaseModel +from agents import Agent, function_tool +from agents.extensions.models.litellm_model import LitellmModel + + +class Report(BaseModel): + summary: str + confidence: float + + +@function_tool +def analyze_data(query: str) -> dict: + return {"result": f"Analysis of {query}"} + + +agent = Agent( + name="Analyst", + model=LitellmModel("gemini/gemini-1.5-flash"), + tools=[analyze_data], + output_type=Report, + enable_structured_output_with_tools=True, # Required for Gemini +) +``` + +The `enable_structured_output_with_tools` parameter enables a workaround that injects JSON formatting instructions into the system prompt instead of using the native API. This allows models like Gemini to return structured outputs even when using tools. + +See the [prompt injection documentation](structured_output_with_tools.md) for complete details. diff --git a/docs/models/structured_output_with_tools.md b/docs/models/structured_output_with_tools.md new file mode 100644 index 000000000..7c00c64fb --- /dev/null +++ b/docs/models/structured_output_with_tools.md @@ -0,0 +1,237 @@ +# Prompt Injection for Structured Outputs + +Some LLM providers don't natively support using tools and structured outputs simultaneously. The Agents SDK includes an opt-in prompt injection feature to work around this limitation. + +!!! note + + This feature is specifically designed for models accessed via [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel], particularly **Google Gemini**. OpenAI models have native support and don't need this workaround. + +## The Problem + +Models like Google Gemini don't support using `tools` and `response_schema` (structured output) in the same API call. When you try: + +```python +from agents import Agent, function_tool +from agents.extensions.models.litellm_model import LitellmModel +from pydantic import BaseModel + +class WeatherReport(BaseModel): + city: str + temperature: float + +@function_tool +def get_weather(city: str) -> dict: + return {"city": city, "temperature": 22.5} + +# This causes an error with Gemini +agent = Agent( + model=LitellmModel("gemini/gemini-1.5-flash"), + tools=[get_weather], + output_type=WeatherReport, # Error: can't use both! +) +``` + +You'll get an error like: + +``` +GeminiException BadRequestError - Function calling with a response mime type +'application/json' is unsupported +``` + +## The Solution + +Enable prompt injection by setting `enable_structured_output_with_tools=True` on your agent: + +```python +agent = Agent( + model=LitellmModel("gemini/gemini-1.5-flash"), + tools=[get_weather], + output_type=WeatherReport, + enable_structured_output_with_tools=True, # ← Enables the workaround +) +``` + +When enabled, the SDK: + +1. Generates JSON formatting instructions from your Pydantic model. +2. Injects these instructions into the system prompt. +3. Disables the native `response_format` parameter to avoid API errors. +4. Parses the model's JSON response into your Pydantic model. + +## Complete Example + +```python +from __future__ import annotations + +import asyncio +from pydantic import BaseModel, Field + +from agents import Agent, Runner, function_tool +from agents.extensions.models.litellm_model import LitellmModel + + +class WeatherReport(BaseModel): + city: str = Field(description="The city name") + temperature: float = Field(description="Temperature in Celsius") + conditions: str = Field(description="Weather conditions") + + +@function_tool +def get_weather(city: str) -> dict: + """Get current weather for a city.""" + return { + "city": city, + "temperature": 22.5, + "conditions": "sunny", + } + + +async def main(): + agent = Agent( + name="WeatherBot", + instructions="Use the get_weather tool, then provide a structured report.", + model=LitellmModel("gemini/gemini-1.5-flash"), + tools=[get_weather], + output_type=WeatherReport, + enable_structured_output_with_tools=True, # Required for Gemini + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + + # Result is properly typed as WeatherReport + report: WeatherReport = result.final_output + print(f"City: {report.city}") + print(f"Temperature: {report.temperature}") + print(f"Conditions: {report.conditions}") + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## When to Use + +| Model Provider | Access Via | Need `enable_structured_output_with_tools`? | +|----------------|-----------|------------------------------| +| Google Gemini | [`LitellmModel("gemini/...")`][agents.extensions.models.litellm_model.LitellmModel] | **Yes** - No native support | +| OpenAI | `"gpt-4o"` (default) | **No** - Has native support | +| Anthropic Claude | [`LitellmModel("claude-...")`][agents.extensions.models.litellm_model.LitellmModel] | **No** - Has native support | +| Other LiteLLM models | [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] | **Try without first** | + +!!! tip + + If you're using [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] and getting errors when combining tools with structured outputs, set `enable_structured_output_with_tools=True`. + +## How It Works + +### Without Prompt Injection (Default) + +The SDK uses the model's native structured output API: + +```python +# API request +{ + "tools": [...], + "response_format": {"type": "json_schema", ...} +} +``` + +This works for OpenAI and Anthropic models but fails for Gemini. + +### With Prompt Injection + +The SDK modifies the request: + +```python +# API request +{ + "system_instruction": "......", + "tools": [...], + "response_format": None # Disabled to avoid errors +} +``` + +The injected instructions tell the model: + +- Which JSON fields to output. +- The type and description of each field. +- How to format the response (valid JSON only). + +### Example Injected Instructions + +For the `WeatherReport` model above, the SDK injects: + +``` +Provide your output as a JSON object containing the following fields: + +["city", "temperature", "conditions"] + + +Here are the properties for each field: + +{ + "city": { + "description": "The city name", + "type": "string" + }, + "temperature": { + "description": "Temperature in Celsius", + "type": "number" + }, + "conditions": { + "description": "Weather conditions", + "type": "string" + } +} + + +IMPORTANT: +- Start your response with `{` and end it with `}` +- Your output will be parsed with json.loads() +- Make sure it only contains valid JSON +- Do NOT include markdown code blocks or any other formatting +``` + +## Debugging + +Enable debug logging to see when prompt injection is active: + +```python +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +Look for: + +``` +DEBUG: Injected JSON output prompt for structured output with tools +``` + +## Best Practices + +1. **Use Pydantic Field descriptions**: The SDK uses these to generate better instructions. + + ```python + class Report(BaseModel): + # Good - includes description + score: float = Field(description="Confidence score from 0 to 1") + + # Less helpful - no description + count: int + ``` + +2. **Test without prompt injection first**: Only enable it if you get errors. + +3. **Use with LiteLLM models only**: OpenAI models ignore this parameter. + +## Limitations + +- The model must be able to follow JSON formatting instructions reliably. +- Parsing errors can occur if the model doesn't output valid JSON. +- This is a workaround, not a replacement for native API support. + +## Related Documentation + +- [Agents](../agents.md) - General agent configuration. +- [LiteLLM models](litellm.md) - Using any model via LiteLLM. +- [Tools](../tools.md) - Defining and using tools. diff --git a/mkdocs.yml b/mkdocs.yml index a1ed06d31..3e76fd318 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -76,6 +76,7 @@ plugins: - Models: - models/index.md - models/litellm.md + - models/structured_output_with_tools.md - config.md - visualization.md - release.md diff --git a/src/agents/agent.py b/src/agents/agent.py index a061926b1..c05a2c02a 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -231,6 +231,16 @@ class Agent(AgentBase, Generic[TContext]): """Whether to reset the tool choice to the default value after a tool has been called. Defaults to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" + enable_structured_output_with_tools: bool = False + """Enable structured outputs when using tools on models that don't natively support both + simultaneously (e.g., Gemini). When enabled, injects JSON formatting instructions into the + system prompt as a workaround instead of using the native API. Defaults to False (use native + API support when available). + + Set to True when using models that don't support both features natively (e.g., Gemini via + LiteLLM). + """ + def __post_init__(self): from typing import get_origin @@ -364,6 +374,12 @@ def __post_init__(self): f"got {type(self.reset_tool_choice).__name__}" ) + if not isinstance(self.enable_structured_output_with_tools, bool): + raise TypeError( + f"Agent enable_structured_output_with_tools must be a boolean, " + f"got {type(self.enable_structured_output_with_tools).__name__}" + ) + def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. Notes: diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 6389b38b2..c0968b87f 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -51,6 +51,7 @@ from ...tracing.spans import Span from ...usage import Usage from ...util._json import _to_dump_compatible +from ...util._prompts import get_json_output_prompt, should_inject_json_prompt class InternalChatCompletionMessage(ChatCompletionMessage): @@ -90,6 +91,7 @@ async def get_response( previous_response_id: str | None = None, # unused conversation_id: str | None = None, # unused prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, ) -> ModelResponse: with generation_span( model=str(self.model), @@ -108,6 +110,7 @@ async def get_response( tracing, stream=False, prompt=prompt, + enable_structured_output_with_tools=enable_structured_output_with_tools, ) message: litellm.types.utils.Message | None = None @@ -194,6 +197,7 @@ async def stream_response( previous_response_id: str | None = None, # unused conversation_id: str | None = None, # unused prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), @@ -212,6 +216,7 @@ async def stream_response( tracing, stream=True, prompt=prompt, + enable_structured_output_with_tools=enable_structured_output_with_tools, ) final_response: Response | None = None @@ -243,6 +248,7 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[True], prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... @overload @@ -258,6 +264,7 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[False], prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, ) -> litellm.types.utils.ModelResponse: ... async def _fetch_response( @@ -272,6 +279,7 @@ async def _fetch_response( tracing: ModelTracing, stream: bool = False, prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]: # Preserve reasoning messages for tool calls when reasoning is on # This is needed for models like Claude 4 Sonnet/Opus which support interleaved thinking @@ -287,6 +295,19 @@ async def _fetch_response( if "anthropic" in self.model.lower() or "claude" in self.model.lower(): converted_messages = self._fix_tool_message_ordering(converted_messages) + # Check if we need to inject JSON output prompt for models that don't support + # tools + structured output simultaneously (like Gemini) + inject_json_prompt = should_inject_json_prompt( + output_schema, tools, enable_structured_output_with_tools + ) + if inject_json_prompt and output_schema: + json_prompt = get_json_output_prompt(output_schema) + if system_instructions: + system_instructions = f"{system_instructions}\n\n{json_prompt}" + else: + system_instructions = json_prompt + logger.debug("Injected JSON output prompt for structured output with tools") + if system_instructions: converted_messages.insert( 0, @@ -308,7 +329,12 @@ async def _fetch_response( else None ) tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) - response_format = Converter.convert_response_format(output_schema) + # Don't use response_format if we injected JSON prompt (avoids API errors) + response_format = ( + Converter.convert_response_format(None) + if inject_json_prompt + else Converter.convert_response_format(output_schema) + ) converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index f25934780..f69946f8b 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -1,125 +1,133 @@ -from __future__ import annotations - -import abc -import enum -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING - -from openai.types.responses.response_prompt_param import ResponsePromptParam - -from ..agent_output import AgentOutputSchemaBase -from ..handoffs import Handoff -from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent -from ..tool import Tool - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -class ModelTracing(enum.Enum): - DISABLED = 0 - """Tracing is disabled entirely.""" - - ENABLED = 1 - """Tracing is enabled, and all data is included.""" - - ENABLED_WITHOUT_DATA = 2 - """Tracing is enabled, but inputs/outputs are not included.""" - - def is_disabled(self) -> bool: - return self == ModelTracing.DISABLED - - def include_data(self) -> bool: - return self == ModelTracing.ENABLED - - -class Model(abc.ABC): - """The base interface for calling an LLM.""" - - @abc.abstractmethod - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: ResponsePromptParam | None, - ) -> ModelResponse: - """Get a response from the model. - - Args: - system_instructions: The system instructions to use. - input: The input items to the model, in OpenAI Responses format. - model_settings: The model settings to use. - tools: The tools available to the model. - output_schema: The output schema to use. - handoffs: The handoffs available to the model. - tracing: Tracing configuration. - previous_response_id: the ID of the previous response. Generally not used by the model, - except for the OpenAI Responses API. - conversation_id: The ID of the stored conversation, if any. - prompt: The prompt config to use for the model. - - Returns: - The full model response. - """ - pass - - @abc.abstractmethod - def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: ResponsePromptParam | None, - ) -> AsyncIterator[TResponseStreamEvent]: - """Stream a response from the model. - - Args: - system_instructions: The system instructions to use. - input: The input items to the model, in OpenAI Responses format. - model_settings: The model settings to use. - tools: The tools available to the model. - output_schema: The output schema to use. - handoffs: The handoffs available to the model. - tracing: Tracing configuration. - previous_response_id: the ID of the previous response. Generally not used by the model, - except for the OpenAI Responses API. - conversation_id: The ID of the stored conversation, if any. - prompt: The prompt config to use for the model. - - Returns: - An iterator of response stream events, in OpenAI Responses format. - """ - pass - - -class ModelProvider(abc.ABC): - """The base interface for a model provider. - - Model provider is responsible for looking up Models by name. - """ - - @abc.abstractmethod - def get_model(self, model_name: str | None) -> Model: - """Get a model by name. - - Args: - model_name: The name of the model to get. - - Returns: - The model. - """ +from __future__ import annotations + +import abc +import enum +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +from openai.types.responses.response_prompt_param import ResponsePromptParam + +from ..agent_output import AgentOutputSchemaBase +from ..handoffs import Handoff +from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent +from ..tool import Tool + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +class ModelTracing(enum.Enum): + DISABLED = 0 + """Tracing is disabled entirely.""" + + ENABLED = 1 + """Tracing is enabled, and all data is included.""" + + ENABLED_WITHOUT_DATA = 2 + """Tracing is enabled, but inputs/outputs are not included.""" + + def is_disabled(self) -> bool: + return self == ModelTracing.DISABLED + + def include_data(self) -> bool: + return self == ModelTracing.ENABLED + + +class Model(abc.ABC): + """The base interface for calling an LLM.""" + + @abc.abstractmethod + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + enable_structured_output_with_tools: bool = False, + ) -> ModelResponse: + """Get a response from the model. + + Args: + system_instructions: The system instructions to use. + input: The input items to the model, in OpenAI Responses format. + model_settings: The model settings to use. + tools: The tools available to the model. + output_schema: The output schema to use. + handoffs: The handoffs available to the model. + tracing: Tracing configuration. + previous_response_id: the ID of the previous response. Generally not used by the model, + except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. + prompt: The prompt config to use for the model. + enable_structured_output_with_tools: Whether to inject JSON formatting instructions + into the system prompt when using structured outputs with tools. Required for + models that don't support both features natively (like Gemini). + + Returns: + The full model response. + """ + pass + + @abc.abstractmethod + def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[TResponseStreamEvent]: + """Stream a response from the model. + + Args: + system_instructions: The system instructions to use. + input: The input items to the model, in OpenAI Responses format. + model_settings: The model settings to use. + tools: The tools available to the model. + output_schema: The output schema to use. + handoffs: The handoffs available to the model. + tracing: Tracing configuration. + previous_response_id: the ID of the previous response. Generally not used by the model, + except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. + prompt: The prompt config to use for the model. + enable_structured_output_with_tools: Whether to inject JSON formatting instructions + into the system prompt when using structured outputs with tools. Required for + models that don't support both features natively (like Gemini). + + Returns: + An iterator of response stream events, in OpenAI Responses format. + """ + pass + + +class ModelProvider(abc.ABC): + """The base interface for a model provider. + + Model provider is responsible for looking up Models by name. + """ + + @abc.abstractmethod + def get_model(self, model_name: str | None) -> Model: + """Get a model by name. + + Args: + model_name: The name of the model to get. + + Returns: + The model. + """ diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index d6cf662d2..56b79ced3 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,359 +1,370 @@ -from __future__ import annotations - -import json -import time -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, Literal, cast, overload - -from openai import AsyncOpenAI, AsyncStream, Omit, omit -from openai.types import ChatModel -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice -from openai.types.responses import Response -from openai.types.responses.response_prompt_param import ResponsePromptParam -from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails - -from .. import _debug -from ..agent_output import AgentOutputSchemaBase -from ..handoffs import Handoff -from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent -from ..logger import logger -from ..tool import Tool -from ..tracing import generation_span -from ..tracing.span_data import GenerationSpanData -from ..tracing.spans import Span -from ..usage import Usage -from ..util._json import _to_dump_compatible -from .chatcmpl_converter import Converter -from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers -from .chatcmpl_stream_handler import ChatCmplStreamHandler -from .fake_id import FAKE_RESPONSES_ID -from .interface import Model, ModelTracing -from .openai_responses import Converter as OpenAIResponsesConverter - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -class OpenAIChatCompletionsModel(Model): - def __init__( - self, - model: str | ChatModel, - openai_client: AsyncOpenAI, - ) -> None: - self.model = model - self._client = openai_client - - def _non_null_or_omit(self, value: Any) -> Any: - return value if value is not None else omit - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused - prompt: ResponsePromptParam | None = None, - ) -> ModelResponse: - with generation_span( - model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, - disabled=tracing.is_disabled(), - ) as span_generation: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=False, - prompt=prompt, - ) - - message: ChatCompletionMessage | None = None - first_choice: Choice | None = None - if response.choices and len(response.choices) > 0: - first_choice = response.choices[0] - message = first_choice.message - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Received model response") - else: - if message is not None: - logger.debug( - "LLM resp:\n%s\n", - json.dumps(message.model_dump(), indent=2, ensure_ascii=False), - ) - else: - finish_reason = first_choice.finish_reason if first_choice else "-" - logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") - - usage = ( - Usage( - requests=1, - input_tokens=response.usage.prompt_tokens, - output_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - input_tokens_details=InputTokensDetails( - cached_tokens=getattr( - response.usage.prompt_tokens_details, "cached_tokens", 0 - ) - or 0, - ), - output_tokens_details=OutputTokensDetails( - reasoning_tokens=getattr( - response.usage.completion_tokens_details, "reasoning_tokens", 0 - ) - or 0, - ), - ) - if response.usage - else Usage() - ) - if tracing.include_data(): - span_generation.span_data.output = ( - [message.model_dump()] if message is not None else [] - ) - span_generation.span_data.usage = { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - } - - items = Converter.message_to_output_items(message) if message is not None else [] - - return ModelResponse( - output=items, - usage=usage, - response_id=None, - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused - prompt: ResponsePromptParam | None = None, - ) -> AsyncIterator[TResponseStreamEvent]: - """ - Yields a partial message as it is generated, as well as the usage information. - """ - with generation_span( - model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, - disabled=tracing.is_disabled(), - ) as span_generation: - response, stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=True, - prompt=prompt, - ) - - final_response: Response | None = None - async for chunk in ChatCmplStreamHandler.handle_stream(response, stream): - yield chunk - - if chunk.type == "response.completed": - final_response = chunk.response - - if tracing.include_data() and final_response: - span_generation.span_data.output = [final_response.model_dump()] - - if final_response and final_response.usage: - span_generation.span_data.usage = { - "input_tokens": final_response.usage.input_tokens, - "output_tokens": final_response.usage.output_tokens, - } - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: Literal[True], - prompt: ResponsePromptParam | None = None, - ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: Literal[False], - prompt: ResponsePromptParam | None = None, - ) -> ChatCompletion: ... - - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: bool = False, - prompt: ResponsePromptParam | None = None, - ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: - converted_messages = Converter.items_to_messages(input) - - if system_instructions: - converted_messages.insert( - 0, - { - "content": system_instructions, - "role": "system", - }, - ) - converted_messages = _to_dump_compatible(converted_messages) - - if tracing.include_data(): - span.span_data.input = converted_messages - - if model_settings.parallel_tool_calls and tools: - parallel_tool_calls: bool | Omit = True - elif model_settings.parallel_tool_calls is False: - parallel_tool_calls = False - else: - parallel_tool_calls = omit - tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) - response_format = Converter.convert_response_format(output_schema) - - converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] - - for handoff in handoffs: - converted_tools.append(Converter.convert_handoff_tool(handoff)) - - converted_tools = _to_dump_compatible(converted_tools) - tools_param = converted_tools if converted_tools else omit - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Calling LLM") - else: - messages_json = json.dumps( - converted_messages, - indent=2, - ensure_ascii=False, - ) - tools_json = json.dumps( - converted_tools, - indent=2, - ensure_ascii=False, - ) - logger.debug( - f"{messages_json}\n" - f"Tools:\n{tools_json}\n" - f"Stream: {stream}\n" - f"Tool choice: {tool_choice}\n" - f"Response format: {response_format}\n" - ) - - reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None - store = ChatCmplHelpers.get_store_param(self._get_client(), model_settings) - - stream_options = ChatCmplHelpers.get_stream_options_param( - self._get_client(), model_settings, stream=stream - ) - - stream_param: Literal[True] | Omit = True if stream else omit - - ret = await self._get_client().chat.completions.create( - model=self.model, - messages=converted_messages, - tools=tools_param, - temperature=self._non_null_or_omit(model_settings.temperature), - top_p=self._non_null_or_omit(model_settings.top_p), - frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty), - presence_penalty=self._non_null_or_omit(model_settings.presence_penalty), - max_tokens=self._non_null_or_omit(model_settings.max_tokens), - tool_choice=tool_choice, - response_format=response_format, - parallel_tool_calls=parallel_tool_calls, - stream=cast(Any, stream_param), - stream_options=self._non_null_or_omit(stream_options), - store=self._non_null_or_omit(store), - reasoning_effort=self._non_null_or_omit(reasoning_effort), - verbosity=self._non_null_or_omit(model_settings.verbosity), - top_logprobs=self._non_null_or_omit(model_settings.top_logprobs), - extra_headers=self._merge_headers(model_settings), - extra_query=model_settings.extra_query, - extra_body=model_settings.extra_body, - metadata=self._non_null_or_omit(model_settings.metadata), - **(model_settings.extra_args or {}), - ) - - if isinstance(ret, ChatCompletion): - return ret - - responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( - model_settings.tool_choice - ) - if responses_tool_choice is None or responses_tool_choice is omit: - # For Responses API data compatibility with Chat Completions patterns, - # we need to set "none" if tool_choice is absent. - # Without this fix, you'll get the following error: - # pydantic_core._pydantic_core.ValidationError: 4 validation errors for Response - # tool_choice.literal['none','auto','required'] - # Input should be 'none', 'auto' or 'required' - # see also: https://github.com/openai/openai-agents-python/issues/980 - responses_tool_choice = "auto" - - response = Response( - id=FAKE_RESPONSES_ID, - created_at=time.time(), - model=self.model, - object="response", - output=[], - tool_choice=responses_tool_choice, # type: ignore[arg-type] - top_p=model_settings.top_p, - temperature=model_settings.temperature, - tools=[], - parallel_tool_calls=parallel_tool_calls or False, - reasoning=model_settings.reasoning, - ) - return response, ret - - def _get_client(self) -> AsyncOpenAI: - if self._client is None: - self._client = AsyncOpenAI() - return self._client - - def _merge_headers(self, model_settings: ModelSettings): - return { - **HEADERS, - **(model_settings.extra_headers or {}), - **(HEADERS_OVERRIDE.get() or {}), - } +from __future__ import annotations + +import json +import time +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, Literal, cast, overload + +from openai import AsyncOpenAI, AsyncStream, Omit, omit +from openai.types import ChatModel +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.responses import Response +from openai.types.responses.response_prompt_param import ResponsePromptParam +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from .. import _debug +from ..agent_output import AgentOutputSchemaBase +from ..handoffs import Handoff +from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent +from ..logger import logger +from ..tool import Tool +from ..tracing import generation_span +from ..tracing.span_data import GenerationSpanData +from ..tracing.spans import Span +from ..usage import Usage +from ..util._json import _to_dump_compatible +from .chatcmpl_converter import Converter +from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers +from .chatcmpl_stream_handler import ChatCmplStreamHandler +from .fake_id import FAKE_RESPONSES_ID +from .interface import Model, ModelTracing +from .openai_responses import Converter as OpenAIResponsesConverter + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +class OpenAIChatCompletionsModel(Model): + def __init__( + self, + model: str | ChatModel, + openai_client: AsyncOpenAI, + ) -> None: + self.model = model + self._client = openai_client + + def _non_null_or_omit(self, value: Any) -> Any: + return value if value is not None else omit + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> ModelResponse: + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + disabled=tracing.is_disabled(), + ) as span_generation: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=False, + prompt=prompt, + enable_structured_output_with_tools=enable_structured_output_with_tools, + ) + + message: ChatCompletionMessage | None = None + first_choice: Choice | None = None + if response.choices and len(response.choices) > 0: + first_choice = response.choices[0] + message = first_choice.message + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Received model response") + else: + if message is not None: + logger.debug( + "LLM resp:\n%s\n", + json.dumps(message.model_dump(), indent=2, ensure_ascii=False), + ) + else: + finish_reason = first_choice.finish_reason if first_choice else "-" + logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), + ) + if response.usage + else Usage() + ) + if tracing.include_data(): + span_generation.span_data.output = ( + [message.model_dump()] if message is not None else [] + ) + span_generation.span_data.usage = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + } + + items = Converter.message_to_output_items(message) if message is not None else [] + + return ModelResponse( + output=items, + usage=usage, + response_id=None, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[TResponseStreamEvent]: + """ + Yields a partial message as it is generated, as well as the usage information. + """ + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + disabled=tracing.is_disabled(), + ) as span_generation: + response, stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=True, + prompt=prompt, + enable_structured_output_with_tools=enable_structured_output_with_tools, + ) + + final_response: Response | None = None + async for chunk in ChatCmplStreamHandler.handle_stream(response, stream): + yield chunk + + if chunk.type == "response.completed": + final_response = chunk.response + + if tracing.include_data() and final_response: + span_generation.span_data.output = [final_response.model_dump()] + + if final_response and final_response.usage: + span_generation.span_data.usage = { + "input_tokens": final_response.usage.input_tokens, + "output_tokens": final_response.usage.output_tokens, + } + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[True], + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[False], + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> ChatCompletion: ... + + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: bool = False, + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: + # Note: enable_structured_output_with_tools parameter is accepted for interface consistency + # but not used for OpenAI models since they have native support for + # tools + structured outputs simultaneously + + converted_messages = Converter.items_to_messages(input) + + if system_instructions: + converted_messages.insert( + 0, + { + "content": system_instructions, + "role": "system", + }, + ) + converted_messages = _to_dump_compatible(converted_messages) + + if tracing.include_data(): + span.span_data.input = converted_messages + + if model_settings.parallel_tool_calls and tools: + parallel_tool_calls: bool | Omit = True + elif model_settings.parallel_tool_calls is False: + parallel_tool_calls = False + else: + parallel_tool_calls = omit + tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) + response_format = Converter.convert_response_format(output_schema) + + converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] + + for handoff in handoffs: + converted_tools.append(Converter.convert_handoff_tool(handoff)) + + converted_tools = _to_dump_compatible(converted_tools) + tools_param = converted_tools if converted_tools else omit + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Calling LLM") + else: + messages_json = json.dumps( + converted_messages, + indent=2, + ensure_ascii=False, + ) + tools_json = json.dumps( + converted_tools, + indent=2, + ensure_ascii=False, + ) + logger.debug( + f"{messages_json}\n" + f"Tools:\n{tools_json}\n" + f"Stream: {stream}\n" + f"Tool choice: {tool_choice}\n" + f"Response format: {response_format}\n" + ) + + reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None + store = ChatCmplHelpers.get_store_param(self._get_client(), model_settings) + + stream_options = ChatCmplHelpers.get_stream_options_param( + self._get_client(), model_settings, stream=stream + ) + + stream_param: Literal[True] | Omit = True if stream else omit + + ret = await self._get_client().chat.completions.create( + model=self.model, + messages=converted_messages, + tools=tools_param, + temperature=self._non_null_or_omit(model_settings.temperature), + top_p=self._non_null_or_omit(model_settings.top_p), + frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty), + presence_penalty=self._non_null_or_omit(model_settings.presence_penalty), + max_tokens=self._non_null_or_omit(model_settings.max_tokens), + tool_choice=tool_choice, + response_format=response_format, + parallel_tool_calls=parallel_tool_calls, + stream=cast(Any, stream_param), + stream_options=self._non_null_or_omit(stream_options), + store=self._non_null_or_omit(store), + reasoning_effort=self._non_null_or_omit(reasoning_effort), + verbosity=self._non_null_or_omit(model_settings.verbosity), + top_logprobs=self._non_null_or_omit(model_settings.top_logprobs), + extra_headers=self._merge_headers(model_settings), + extra_query=model_settings.extra_query, + extra_body=model_settings.extra_body, + metadata=self._non_null_or_omit(model_settings.metadata), + **(model_settings.extra_args or {}), + ) + + if isinstance(ret, ChatCompletion): + return ret + + responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( + model_settings.tool_choice + ) + if responses_tool_choice is None or responses_tool_choice is omit: + # For Responses API data compatibility with Chat Completions patterns, + # we need to set "none" if tool_choice is absent. + # Without this fix, you'll get the following error: + # pydantic_core._pydantic_core.ValidationError: 4 validation errors for Response + # tool_choice.literal['none','auto','required'] + # Input should be 'none', 'auto' or 'required' + # see also: https://github.com/openai/openai-agents-python/issues/980 + responses_tool_choice = "auto" + + response = Response( + id=FAKE_RESPONSES_ID, + created_at=time.time(), + model=self.model, + object="response", + output=[], + tool_choice=responses_tool_choice, # type: ignore[arg-type] + top_p=model_settings.top_p, + temperature=model_settings.temperature, + tools=[], + parallel_tool_calls=parallel_tool_calls or False, + reasoning=model_settings.reasoning, + ) + return response, ret + + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = AsyncOpenAI() + return self._client + + def _merge_headers(self, model_settings: ModelSettings): + return { + **HEADERS, + **(model_settings.extra_headers or {}), + **(HEADERS_OVERRIDE.get() or {}), + } diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 36a981404..ca48d17d8 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -1,516 +1,518 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncIterator -from contextvars import ContextVar -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload - -from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit -from openai.types import ChatModel -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseIncludable, - ResponseStreamEvent, - ResponseTextConfigParam, - ToolParam, - response_create_params, -) -from openai.types.responses.response_prompt_param import ResponsePromptParam - -from .. import _debug -from ..agent_output import AgentOutputSchemaBase -from ..exceptions import UserError -from ..handoffs import Handoff -from ..items import ItemHelpers, ModelResponse, TResponseInputItem -from ..logger import logger -from ..model_settings import MCPToolChoice -from ..tool import ( - CodeInterpreterTool, - ComputerTool, - FileSearchTool, - FunctionTool, - HostedMCPTool, - ImageGenerationTool, - LocalShellTool, - Tool, - WebSearchTool, -) -from ..tracing import SpanError, response_span -from ..usage import Usage -from ..util._json import _to_dump_compatible -from ..version import __version__ -from .interface import Model, ModelTracing - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -_USER_AGENT = f"Agents/Python {__version__}" -_HEADERS = {"User-Agent": _USER_AGENT} - -# Override headers used by the Responses API. -_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( - "openai_responses_headers_override", default=None -) - - -class OpenAIResponsesModel(Model): - """ - Implementation of `Model` that uses the OpenAI Responses API. - """ - - def __init__( - self, - model: str | ChatModel, - openai_client: AsyncOpenAI, - ) -> None: - self.model = model - self._client = openai_client - - def _non_null_or_omit(self, value: Any) -> Any: - return value if value is not None else omit - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: ResponsePromptParam | None = None, - ) -> ModelResponse: - with response_span(disabled=tracing.is_disabled()) as span_response: - try: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - stream=False, - prompt=prompt, - ) - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("LLM responded") - else: - logger.debug( - "LLM resp:\n" - f"""{ - json.dumps( - [x.model_dump() for x in response.output], - indent=2, - ensure_ascii=False, - ) - }\n""" - ) - - usage = ( - Usage( - requests=1, - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens, - total_tokens=response.usage.total_tokens, - input_tokens_details=response.usage.input_tokens_details, - output_tokens_details=response.usage.output_tokens_details, - ) - if response.usage - else Usage() - ) - - if tracing.include_data(): - span_response.span_data.response = response - span_response.span_data.input = input - except Exception as e: - span_response.set_error( - SpanError( - message="Error getting response", - data={ - "error": str(e) if tracing.include_data() else e.__class__.__name__, - }, - ) - ) - request_id = e.request_id if isinstance(e, APIStatusError) else None - logger.error(f"Error getting response: {e}. (request_id: {request_id})") - raise - - return ModelResponse( - output=response.output, - usage=usage, - response_id=response.id, - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: ResponsePromptParam | None = None, - ) -> AsyncIterator[ResponseStreamEvent]: - """ - Yields a partial message as it is generated, as well as the usage information. - """ - with response_span(disabled=tracing.is_disabled()) as span_response: - try: - stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - stream=True, - prompt=prompt, - ) - - final_response: Response | None = None - - async for chunk in stream: - if isinstance(chunk, ResponseCompletedEvent): - final_response = chunk.response - yield chunk - - if final_response and tracing.include_data(): - span_response.span_data.response = final_response - span_response.span_data.input = input - - except Exception as e: - span_response.set_error( - SpanError( - message="Error streaming response", - data={ - "error": str(e) if tracing.include_data() else e.__class__.__name__, - }, - ) - ) - logger.error(f"Error streaming response: {e}") - raise - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None, - conversation_id: str | None, - stream: Literal[True], - prompt: ResponsePromptParam | None = None, - ) -> AsyncStream[ResponseStreamEvent]: ... - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None, - conversation_id: str | None, - stream: Literal[False], - prompt: ResponsePromptParam | None = None, - ) -> Response: ... - - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None = None, - conversation_id: str | None = None, - stream: Literal[True] | Literal[False] = False, - prompt: ResponsePromptParam | None = None, - ) -> Response | AsyncStream[ResponseStreamEvent]: - list_input = ItemHelpers.input_to_new_input_list(input) - list_input = _to_dump_compatible(list_input) - - if model_settings.parallel_tool_calls and tools: - parallel_tool_calls: bool | Omit = True - elif model_settings.parallel_tool_calls is False: - parallel_tool_calls = False - else: - parallel_tool_calls = omit - - tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) - converted_tools = Converter.convert_tools(tools, handoffs) - converted_tools_payload = _to_dump_compatible(converted_tools.tools) - response_format = Converter.get_response_format(output_schema) - - include_set: set[str] = set(converted_tools.includes) - if model_settings.response_include is not None: - include_set.update(model_settings.response_include) - if model_settings.top_logprobs is not None: - include_set.add("message.output_text.logprobs") - include = cast(list[ResponseIncludable], list(include_set)) - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Calling LLM") - else: - input_json = json.dumps( - list_input, - indent=2, - ensure_ascii=False, - ) - tools_json = json.dumps( - converted_tools_payload, - indent=2, - ensure_ascii=False, - ) - logger.debug( - f"Calling LLM {self.model} with input:\n" - f"{input_json}\n" - f"Tools:\n{tools_json}\n" - f"Stream: {stream}\n" - f"Tool choice: {tool_choice}\n" - f"Response format: {response_format}\n" - f"Previous response id: {previous_response_id}\n" - f"Conversation id: {conversation_id}\n" - ) - - extra_args = dict(model_settings.extra_args or {}) - if model_settings.top_logprobs is not None: - extra_args["top_logprobs"] = model_settings.top_logprobs - if model_settings.verbosity is not None: - if response_format is not omit: - response_format["verbosity"] = model_settings.verbosity # type: ignore [index] - else: - response_format = {"verbosity": model_settings.verbosity} - - stream_param: Literal[True] | Omit = True if stream else omit - - response = await self._client.responses.create( - previous_response_id=self._non_null_or_omit(previous_response_id), - conversation=self._non_null_or_omit(conversation_id), - instructions=self._non_null_or_omit(system_instructions), - model=self.model, - input=list_input, - include=include, - tools=converted_tools_payload, - prompt=self._non_null_or_omit(prompt), - temperature=self._non_null_or_omit(model_settings.temperature), - top_p=self._non_null_or_omit(model_settings.top_p), - truncation=self._non_null_or_omit(model_settings.truncation), - max_output_tokens=self._non_null_or_omit(model_settings.max_tokens), - tool_choice=tool_choice, - parallel_tool_calls=parallel_tool_calls, - stream=cast(Any, stream_param), - extra_headers=self._merge_headers(model_settings), - extra_query=model_settings.extra_query, - extra_body=model_settings.extra_body, - text=response_format, - store=self._non_null_or_omit(model_settings.store), - reasoning=self._non_null_or_omit(model_settings.reasoning), - metadata=self._non_null_or_omit(model_settings.metadata), - **extra_args, - ) - return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response) - - def _get_client(self) -> AsyncOpenAI: - if self._client is None: - self._client = AsyncOpenAI() - return self._client - - def _merge_headers(self, model_settings: ModelSettings): - return { - **_HEADERS, - **(model_settings.extra_headers or {}), - **(_HEADERS_OVERRIDE.get() or {}), - } - - -@dataclass -class ConvertedTools: - tools: list[ToolParam] - includes: list[ResponseIncludable] - - -class Converter: - @classmethod - def convert_tool_choice( - cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None - ) -> response_create_params.ToolChoice | Omit: - if tool_choice is None: - return omit - elif isinstance(tool_choice, MCPToolChoice): - return { - "server_label": tool_choice.server_label, - "type": "mcp", - "name": tool_choice.name, - } - elif tool_choice == "required": - return "required" - elif tool_choice == "auto": - return "auto" - elif tool_choice == "none": - return "none" - elif tool_choice == "file_search": - return { - "type": "file_search", - } - elif tool_choice == "web_search": - return { - # TODO: revist the type: ignore comment when ToolChoice is updated in the future - "type": "web_search", # type: ignore [typeddict-item] - } - elif tool_choice == "web_search_preview": - return { - "type": "web_search_preview", - } - elif tool_choice == "computer_use_preview": - return { - "type": "computer_use_preview", - } - elif tool_choice == "image_generation": - return { - "type": "image_generation", - } - elif tool_choice == "code_interpreter": - return { - "type": "code_interpreter", - } - elif tool_choice == "mcp": - # Note that this is still here for backwards compatibility, - # but migrating to MCPToolChoice is recommended. - return {"type": "mcp"} # type: ignore [typeddict-item] - else: - return { - "type": "function", - "name": tool_choice, - } - - @classmethod - def get_response_format( - cls, output_schema: AgentOutputSchemaBase | None - ) -> ResponseTextConfigParam | Omit: - if output_schema is None or output_schema.is_plain_text(): - return omit - else: - return { - "format": { - "type": "json_schema", - "name": "final_output", - "schema": output_schema.json_schema(), - "strict": output_schema.is_strict_json_schema(), - } - } - - @classmethod - def convert_tools( - cls, - tools: list[Tool], - handoffs: list[Handoff[Any, Any]], - ) -> ConvertedTools: - converted_tools: list[ToolParam] = [] - includes: list[ResponseIncludable] = [] - - computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] - if len(computer_tools) > 1: - raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}") - - for tool in tools: - converted_tool, include = cls._convert_tool(tool) - converted_tools.append(converted_tool) - if include: - includes.append(include) - - for handoff in handoffs: - converted_tools.append(cls._convert_handoff_tool(handoff)) - - return ConvertedTools(tools=converted_tools, includes=includes) - - @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: - """Returns converted tool and includes""" - - if isinstance(tool, FunctionTool): - converted_tool: ToolParam = { - "name": tool.name, - "parameters": tool.params_json_schema, - "strict": tool.strict_json_schema, - "type": "function", - "description": tool.description, - } - includes: ResponseIncludable | None = None - elif isinstance(tool, WebSearchTool): - # TODO: revist the type: ignore comment when ToolParam is updated in the future - converted_tool = { - "type": "web_search", - "filters": tool.filters.model_dump() if tool.filters is not None else None, # type: ignore [typeddict-item] - "user_location": tool.user_location, - "search_context_size": tool.search_context_size, - } - includes = None - elif isinstance(tool, FileSearchTool): - converted_tool = { - "type": "file_search", - "vector_store_ids": tool.vector_store_ids, - } - if tool.max_num_results: - converted_tool["max_num_results"] = tool.max_num_results - if tool.ranking_options: - converted_tool["ranking_options"] = tool.ranking_options - if tool.filters: - converted_tool["filters"] = tool.filters - - includes = "file_search_call.results" if tool.include_search_results else None - elif isinstance(tool, ComputerTool): - converted_tool = { - "type": "computer_use_preview", - "environment": tool.computer.environment, - "display_width": tool.computer.dimensions[0], - "display_height": tool.computer.dimensions[1], - } - includes = None - elif isinstance(tool, HostedMCPTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, ImageGenerationTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, CodeInterpreterTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, LocalShellTool): - converted_tool = { - "type": "local_shell", - } - includes = None - else: - raise UserError(f"Unknown tool type: {type(tool)}, tool") - - return converted_tool, includes - - @classmethod - def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam: - return { - "name": handoff.tool_name, - "parameters": handoff.input_json_schema, - "strict": handoff.strict_json_schema, - "type": "function", - "description": handoff.tool_description, - } +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from contextvars import ContextVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload + +from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit +from openai.types import ChatModel +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseIncludable, + ResponseStreamEvent, + ResponseTextConfigParam, + ToolParam, + response_create_params, +) +from openai.types.responses.response_prompt_param import ResponsePromptParam + +from .. import _debug +from ..agent_output import AgentOutputSchemaBase +from ..exceptions import UserError +from ..handoffs import Handoff +from ..items import ItemHelpers, ModelResponse, TResponseInputItem +from ..logger import logger +from ..model_settings import MCPToolChoice +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) +from ..tracing import SpanError, response_span +from ..usage import Usage +from ..util._json import _to_dump_compatible +from ..version import __version__ +from .interface import Model, ModelTracing + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +_USER_AGENT = f"Agents/Python {__version__}" +_HEADERS = {"User-Agent": _USER_AGENT} + +# Override headers used by the Responses API. +_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( + "openai_responses_headers_override", default=None +) + + +class OpenAIResponsesModel(Model): + """ + Implementation of `Model` that uses the OpenAI Responses API. + """ + + def __init__( + self, + model: str | ChatModel, + openai_client: AsyncOpenAI, + ) -> None: + self.model = model + self._client = openai_client + + def _non_null_or_omit(self, value: Any) -> Any: + return value if value is not None else omit + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> ModelResponse: + with response_span(disabled=tracing.is_disabled()) as span_response: + try: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=False, + prompt=prompt, + ) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("LLM responded") + else: + logger.debug( + "LLM resp:\n" + f"""{ + json.dumps( + [x.model_dump() for x in response.output], + indent=2, + ensure_ascii=False, + ) + }\n""" + ) + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, + ) + if response.usage + else Usage() + ) + + if tracing.include_data(): + span_response.span_data.response = response + span_response.span_data.input = input + except Exception as e: + span_response.set_error( + SpanError( + message="Error getting response", + data={ + "error": str(e) if tracing.include_data() else e.__class__.__name__, + }, + ) + ) + request_id = e.request_id if isinstance(e, APIStatusError) else None + logger.error(f"Error getting response: {e}. (request_id: {request_id})") + raise + + return ModelResponse( + output=response.output, + usage=usage, + response_id=response.id, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[ResponseStreamEvent]: + """ + Yields a partial message as it is generated, as well as the usage information. + """ + with response_span(disabled=tracing.is_disabled()) as span_response: + try: + stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=True, + prompt=prompt, + ) + + final_response: Response | None = None + + async for chunk in stream: + if isinstance(chunk, ResponseCompletedEvent): + final_response = chunk.response + yield chunk + + if final_response and tracing.include_data(): + span_response.span_data.response = final_response + span_response.span_data.input = input + + except Exception as e: + span_response.set_error( + SpanError( + message="Error streaming response", + data={ + "error": str(e) if tracing.include_data() else e.__class__.__name__, + }, + ) + ) + logger.error(f"Error streaming response: {e}") + raise + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[True], + prompt: ResponsePromptParam | None = None, + ) -> AsyncStream[ResponseStreamEvent]: ... + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[False], + prompt: ResponsePromptParam | None = None, + ) -> Response: ... + + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None = None, + conversation_id: str | None = None, + stream: Literal[True] | Literal[False] = False, + prompt: ResponsePromptParam | None = None, + ) -> Response | AsyncStream[ResponseStreamEvent]: + list_input = ItemHelpers.input_to_new_input_list(input) + list_input = _to_dump_compatible(list_input) + + if model_settings.parallel_tool_calls and tools: + parallel_tool_calls: bool | Omit = True + elif model_settings.parallel_tool_calls is False: + parallel_tool_calls = False + else: + parallel_tool_calls = omit + + tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) + converted_tools = Converter.convert_tools(tools, handoffs) + converted_tools_payload = _to_dump_compatible(converted_tools.tools) + response_format = Converter.get_response_format(output_schema) + + include_set: set[str] = set(converted_tools.includes) + if model_settings.response_include is not None: + include_set.update(model_settings.response_include) + if model_settings.top_logprobs is not None: + include_set.add("message.output_text.logprobs") + include = cast(list[ResponseIncludable], list(include_set)) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Calling LLM") + else: + input_json = json.dumps( + list_input, + indent=2, + ensure_ascii=False, + ) + tools_json = json.dumps( + converted_tools_payload, + indent=2, + ensure_ascii=False, + ) + logger.debug( + f"Calling LLM {self.model} with input:\n" + f"{input_json}\n" + f"Tools:\n{tools_json}\n" + f"Stream: {stream}\n" + f"Tool choice: {tool_choice}\n" + f"Response format: {response_format}\n" + f"Previous response id: {previous_response_id}\n" + f"Conversation id: {conversation_id}\n" + ) + + extra_args = dict(model_settings.extra_args or {}) + if model_settings.top_logprobs is not None: + extra_args["top_logprobs"] = model_settings.top_logprobs + if model_settings.verbosity is not None: + if response_format is not omit: + response_format["verbosity"] = model_settings.verbosity # type: ignore [index] + else: + response_format = {"verbosity": model_settings.verbosity} + + stream_param: Literal[True] | Omit = True if stream else omit + + response = await self._client.responses.create( + previous_response_id=self._non_null_or_omit(previous_response_id), + conversation=self._non_null_or_omit(conversation_id), + instructions=self._non_null_or_omit(system_instructions), + model=self.model, + input=list_input, + include=include, + tools=converted_tools_payload, + prompt=self._non_null_or_omit(prompt), + temperature=self._non_null_or_omit(model_settings.temperature), + top_p=self._non_null_or_omit(model_settings.top_p), + truncation=self._non_null_or_omit(model_settings.truncation), + max_output_tokens=self._non_null_or_omit(model_settings.max_tokens), + tool_choice=tool_choice, + parallel_tool_calls=parallel_tool_calls, + stream=cast(Any, stream_param), + extra_headers=self._merge_headers(model_settings), + extra_query=model_settings.extra_query, + extra_body=model_settings.extra_body, + text=response_format, + store=self._non_null_or_omit(model_settings.store), + reasoning=self._non_null_or_omit(model_settings.reasoning), + metadata=self._non_null_or_omit(model_settings.metadata), + **extra_args, + ) + return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response) + + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = AsyncOpenAI() + return self._client + + def _merge_headers(self, model_settings: ModelSettings): + return { + **_HEADERS, + **(model_settings.extra_headers or {}), + **(_HEADERS_OVERRIDE.get() or {}), + } + + +@dataclass +class ConvertedTools: + tools: list[ToolParam] + includes: list[ResponseIncludable] + + +class Converter: + @classmethod + def convert_tool_choice( + cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None + ) -> response_create_params.ToolChoice | Omit: + if tool_choice is None: + return omit + elif isinstance(tool_choice, MCPToolChoice): + return { + "server_label": tool_choice.server_label, + "type": "mcp", + "name": tool_choice.name, + } + elif tool_choice == "required": + return "required" + elif tool_choice == "auto": + return "auto" + elif tool_choice == "none": + return "none" + elif tool_choice == "file_search": + return { + "type": "file_search", + } + elif tool_choice == "web_search": + return { + # TODO: revist the type: ignore comment when ToolChoice is updated in the future + "type": "web_search", # type: ignore [typeddict-item] + } + elif tool_choice == "web_search_preview": + return { + "type": "web_search_preview", + } + elif tool_choice == "computer_use_preview": + return { + "type": "computer_use_preview", + } + elif tool_choice == "image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + # Note that this is still here for backwards compatibility, + # but migrating to MCPToolChoice is recommended. + return {"type": "mcp"} # type: ignore [typeddict-item] + else: + return { + "type": "function", + "name": tool_choice, + } + + @classmethod + def get_response_format( + cls, output_schema: AgentOutputSchemaBase | None + ) -> ResponseTextConfigParam | Omit: + if output_schema is None or output_schema.is_plain_text(): + return omit + else: + return { + "format": { + "type": "json_schema", + "name": "final_output", + "schema": output_schema.json_schema(), + "strict": output_schema.is_strict_json_schema(), + } + } + + @classmethod + def convert_tools( + cls, + tools: list[Tool], + handoffs: list[Handoff[Any, Any]], + ) -> ConvertedTools: + converted_tools: list[ToolParam] = [] + includes: list[ResponseIncludable] = [] + + computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] + if len(computer_tools) > 1: + raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}") + + for tool in tools: + converted_tool, include = cls._convert_tool(tool) + converted_tools.append(converted_tool) + if include: + includes.append(include) + + for handoff in handoffs: + converted_tools.append(cls._convert_handoff_tool(handoff)) + + return ConvertedTools(tools=converted_tools, includes=includes) + + @classmethod + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: + """Returns converted tool and includes""" + + if isinstance(tool, FunctionTool): + converted_tool: ToolParam = { + "name": tool.name, + "parameters": tool.params_json_schema, + "strict": tool.strict_json_schema, + "type": "function", + "description": tool.description, + } + includes: ResponseIncludable | None = None + elif isinstance(tool, WebSearchTool): + # TODO: revist the type: ignore comment when ToolParam is updated in the future + converted_tool = { + "type": "web_search", + "filters": tool.filters.model_dump() if tool.filters is not None else None, # type: ignore [typeddict-item] + "user_location": tool.user_location, + "search_context_size": tool.search_context_size, + } + includes = None + elif isinstance(tool, FileSearchTool): + converted_tool = { + "type": "file_search", + "vector_store_ids": tool.vector_store_ids, + } + if tool.max_num_results: + converted_tool["max_num_results"] = tool.max_num_results + if tool.ranking_options: + converted_tool["ranking_options"] = tool.ranking_options + if tool.filters: + converted_tool["filters"] = tool.filters + + includes = "file_search_call.results" if tool.include_search_results else None + elif isinstance(tool, ComputerTool): + converted_tool = { + "type": "computer_use_preview", + "environment": tool.computer.environment, + "display_width": tool.computer.dimensions[0], + "display_height": tool.computer.dimensions[1], + } + includes = None + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, ImageGenerationTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None + else: + raise UserError(f"Unknown tool type: {type(tool)}, tool") + + return converted_tool, includes + + @classmethod + def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam: + return { + "name": handoff.tool_name, + "parameters": handoff.input_json_schema, + "strict": handoff.strict_json_schema, + "type": "function", + "description": handoff.tool_description, + } diff --git a/src/agents/run.py b/src/agents/run.py index 5b25df4f2..69cc88815 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1300,6 +1300,7 @@ async def _run_single_turn_streamed( previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt_config, + enable_structured_output_with_tools=agent.enable_structured_output_with_tools, ): # Emit the raw event ASAP streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) @@ -1735,6 +1736,7 @@ async def _get_new_response( previous_response_id=previous_response_id, conversation_id=conversation_id, prompt=prompt_config, + enable_structured_output_with_tools=agent.enable_structured_output_with_tools, ) context_wrapper.usage.add(new_response.usage) diff --git a/src/agents/util/_prompts.py b/src/agents/util/_prompts.py new file mode 100644 index 000000000..df20f4696 --- /dev/null +++ b/src/agents/util/_prompts.py @@ -0,0 +1,117 @@ +"""Utility functions for generating prompts for structured outputs.""" + +import json +import logging +from typing import Any + +from ..agent_output import AgentOutputSchemaBase + +logger = logging.getLogger(__name__) + + +def get_json_output_prompt(output_schema: AgentOutputSchemaBase) -> str: + if output_schema.is_plain_text(): + return "" + + json_output_prompt = "\n\nProvide your output as a JSON object containing the following fields:" + + try: + json_schema = output_schema.json_schema() + + # Extract field names and properties + response_model_properties = {} + json_schema_properties = json_schema.get("properties", {}) + + for field_name, field_properties in json_schema_properties.items(): + formatted_field_properties = { + prop_name: prop_value + for prop_name, prop_value in field_properties.items() + if prop_name != "title" + } + + # Handle enum references + if "allOf" in formatted_field_properties: + ref = formatted_field_properties["allOf"][0].get("$ref", "") + if ref.startswith("#/$defs/"): + enum_name = ref.split("/")[-1] + formatted_field_properties["enum_type"] = enum_name + + response_model_properties[field_name] = formatted_field_properties + + # Handle definitions (nested objects, enums, etc.) + json_schema_defs = json_schema.get("$defs") + if json_schema_defs is not None: + response_model_properties["$defs"] = {} + for def_name, def_properties in json_schema_defs.items(): + if "enum" in def_properties: + # Enum definition + response_model_properties["$defs"][def_name] = { + "type": "string", + "enum": def_properties["enum"], + "description": def_properties.get("description", ""), + } + else: + # Regular object definition + def_fields = def_properties.get("properties") + formatted_def_properties = {} + if def_fields is not None: + for field_name, field_properties in def_fields.items(): + formatted_field_properties = { + prop_name: prop_value + for prop_name, prop_value in field_properties.items() + if prop_name != "title" + } + formatted_def_properties[field_name] = formatted_field_properties + if len(formatted_def_properties) > 0: + response_model_properties["$defs"][def_name] = formatted_def_properties + + if len(response_model_properties) > 0: + # List field names + field_names = [key for key in response_model_properties.keys() if key != "$defs"] + json_output_prompt += "\n" + json_output_prompt += f"\n{json.dumps(field_names)}" + json_output_prompt += "\n" + + # Provide detailed properties + json_output_prompt += "\n\nHere are the properties for each field:" + json_output_prompt += "\n" + json_output_prompt += f"\n{json.dumps(response_model_properties, indent=2)}" + json_output_prompt += "\n" + + except (AttributeError, KeyError, TypeError, ValueError) as e: + # Fallback to simple instruction if schema generation fails + logger.warning( + f"Failed to generate detailed JSON schema for prompt injection: {e}. " + f"Using simple fallback for output type: {output_schema.name()}" + ) + json_output_prompt += f"\nOutput type: {output_schema.name()}" + except Exception as e: + # Catch any other unexpected errors but log them as errors + logger.error( + f"Unexpected error generating JSON prompt for {output_schema.name()}: {e}", + exc_info=True, + ) + json_output_prompt += f"\nOutput type: {output_schema.name()}" + + json_output_prompt += "\n\nIMPORTANT:" + json_output_prompt += "\n- Start your response with `{` and end it with `}`" + json_output_prompt += "\n- Your output will be parsed with json.loads()" + json_output_prompt += "\n- Make sure it only contains valid JSON" + json_output_prompt += "\n- Do NOT include markdown code blocks or any other formatting" + + return json_output_prompt + + +def should_inject_json_prompt( + output_schema: AgentOutputSchemaBase | None, + tools: list[Any], + enable_structured_output_with_tools: bool = False, +) -> bool: + if output_schema is None or output_schema.is_plain_text(): + return False + + # Only inject if explicitly requested by user AND both tools and output_schema are present + if enable_structured_output_with_tools and tools and len(tools) > 0: + return True + + return False diff --git a/tests/fake_model.py b/tests/fake_model.py index 6e13a02a4..e6898dbe1 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -1,343 +1,345 @@ -from __future__ import annotations - -from collections.abc import AsyncIterator -from typing import Any - -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseFunctionCallArgumentsDeltaEvent, - ResponseFunctionCallArgumentsDoneEvent, - ResponseFunctionToolCall, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, - ResponseReasoningSummaryPartAddedEvent, - ResponseReasoningSummaryPartDoneEvent, - ResponseReasoningSummaryTextDeltaEvent, - ResponseReasoningSummaryTextDoneEvent, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ResponseUsage, -) -from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from openai.types.responses.response_reasoning_summary_part_added_event import ( - Part as AddedEventPart, -) -from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart -from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails - -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import ( - ModelResponse, - TResponseInputItem, - TResponseOutputItem, - TResponseStreamEvent, -) -from agents.model_settings import ModelSettings -from agents.models.interface import Model, ModelTracing -from agents.tool import Tool -from agents.tracing import SpanError, generation_span -from agents.usage import Usage - - -class FakeModel(Model): - def __init__( - self, - tracing_enabled: bool = False, - initial_output: list[TResponseOutputItem] | Exception | None = None, - ): - if initial_output is None: - initial_output = [] - self.turn_outputs: list[list[TResponseOutputItem] | Exception] = ( - [initial_output] if initial_output else [] - ) - self.tracing_enabled = tracing_enabled - self.last_turn_args: dict[str, Any] = {} - self.first_turn_args: dict[str, Any] | None = None - self.hardcoded_usage: Usage | None = None - - def set_hardcoded_usage(self, usage: Usage): - self.hardcoded_usage = usage - - def set_next_output(self, output: list[TResponseOutputItem] | Exception): - self.turn_outputs.append(output) - - def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]): - self.turn_outputs.extend(outputs) - - def get_next_output(self) -> list[TResponseOutputItem] | Exception: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> ModelResponse: - turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - if self.first_turn_args is None: - self.first_turn_args = turn_args.copy() - - self.last_turn_args = turn_args - - with generation_span(disabled=not self.tracing_enabled) as span: - output = self.get_next_output() - - if isinstance(output, Exception): - span.set_error( - SpanError( - message="Error", - data={ - "name": output.__class__.__name__, - "message": str(output), - }, - ) - ) - raise output - - return ModelResponse( - output=output, - usage=self.hardcoded_usage or Usage(), - response_id="resp-789", - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: Any | None = None, - ) -> AsyncIterator[TResponseStreamEvent]: - turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - if self.first_turn_args is None: - self.first_turn_args = turn_args.copy() - - self.last_turn_args = turn_args - with generation_span(disabled=not self.tracing_enabled) as span: - output = self.get_next_output() - if isinstance(output, Exception): - span.set_error( - SpanError( - message="Error", - data={ - "name": output.__class__.__name__, - "message": str(output), - }, - ) - ) - raise output - - response = get_response_obj(output, usage=self.hardcoded_usage) - sequence_number = 0 - - yield ResponseCreatedEvent( - type="response.created", - response=response, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseInProgressEvent( - type="response.in_progress", - response=response, - sequence_number=sequence_number, - ) - sequence_number += 1 - - for output_index, output_item in enumerate(output): - yield ResponseOutputItemAddedEvent( - type="response.output_item.added", - item=output_item, - output_index=output_index, - sequence_number=sequence_number, - ) - sequence_number += 1 - - if isinstance(output_item, ResponseReasoningItem): - if output_item.summary: - for summary_index, summary in enumerate(output_item.summary): - yield ResponseReasoningSummaryPartAddedEvent( - type="response.reasoning_summary_part.added", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - part=AddedEventPart(text=summary.text, type=summary.type), - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryTextDeltaEvent( - type="response.reasoning_summary_text.delta", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - delta=summary.text, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryTextDoneEvent( - type="response.reasoning_summary_text.done", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - text=summary.text, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryPartDoneEvent( - type="response.reasoning_summary_part.done", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - part=DoneEventPart(text=summary.text, type=summary.type), - sequence_number=sequence_number, - ) - sequence_number += 1 - - elif isinstance(output_item, ResponseFunctionToolCall): - yield ResponseFunctionCallArgumentsDeltaEvent( - type="response.function_call_arguments.delta", - item_id=output_item.call_id, - output_index=output_index, - delta=output_item.arguments, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseFunctionCallArgumentsDoneEvent( - type="response.function_call_arguments.done", - item_id=output_item.call_id, - output_index=output_index, - arguments=output_item.arguments, - name=output_item.name, - sequence_number=sequence_number, - ) - sequence_number += 1 - - elif isinstance(output_item, ResponseOutputMessage): - for content_index, content_part in enumerate(output_item.content): - if isinstance(content_part, ResponseOutputText): - yield ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - part=content_part, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - delta=content_part.text, - logprobs=[], - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseTextDoneEvent( - type="response.output_text.done", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - text=content_part.text, - logprobs=[], - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - part=content_part, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseOutputItemDoneEvent( - type="response.output_item.done", - item=output_item, - output_index=output_index, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseCompletedEvent( - type="response.completed", - response=response, - sequence_number=sequence_number, - ) - - -def get_response_obj( - output: list[TResponseOutputItem], - response_id: str | None = None, - usage: Usage | None = None, -) -> Response: - return Response( - id=response_id or "resp-789", - created_at=123, - model="test_model", - object="response", - output=output, - tool_choice="none", - tools=[], - top_p=None, - parallel_tool_calls=False, - usage=ResponseUsage( - input_tokens=usage.input_tokens if usage else 0, - output_tokens=usage.output_tokens if usage else 0, - total_tokens=usage.total_tokens if usage else 0, - input_tokens_details=InputTokensDetails(cached_tokens=0), - output_tokens_details=OutputTokensDetails(reasoning_tokens=0), - ), - ) +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningSummaryPartAddedEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseReasoningSummaryTextDeltaEvent, + ResponseReasoningSummaryTextDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ResponseUsage, +) +from openai.types.responses.response_reasoning_item import ResponseReasoningItem +from openai.types.responses.response_reasoning_summary_part_added_event import ( + Part as AddedEventPart, +) +from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import ( + ModelResponse, + TResponseInputItem, + TResponseOutputItem, + TResponseStreamEvent, +) +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.tool import Tool +from agents.tracing import SpanError, generation_span +from agents.usage import Usage + + +class FakeModel(Model): + def __init__( + self, + tracing_enabled: bool = False, + initial_output: list[TResponseOutputItem] | Exception | None = None, + ): + if initial_output is None: + initial_output = [] + self.turn_outputs: list[list[TResponseOutputItem] | Exception] = ( + [initial_output] if initial_output else [] + ) + self.tracing_enabled = tracing_enabled + self.last_turn_args: dict[str, Any] = {} + self.first_turn_args: dict[str, Any] | None = None + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage + + def set_next_output(self, output: list[TResponseOutputItem] | Exception): + self.turn_outputs.append(output) + + def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]): + self.turn_outputs.extend(outputs) + + def get_next_output(self) -> list[TResponseOutputItem] | Exception: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + enable_structured_output_with_tools: bool = False, + ) -> ModelResponse: + turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + if self.first_turn_args is None: + self.first_turn_args = turn_args.copy() + + self.last_turn_args = turn_args + + with generation_span(disabled=not self.tracing_enabled) as span: + output = self.get_next_output() + + if isinstance(output, Exception): + span.set_error( + SpanError( + message="Error", + data={ + "name": output.__class__.__name__, + "message": str(output), + }, + ) + ) + raise output + + return ModelResponse( + output=output, + usage=self.hardcoded_usage or Usage(), + response_id="resp-789", + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: Any | None = None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[TResponseStreamEvent]: + turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + if self.first_turn_args is None: + self.first_turn_args = turn_args.copy() + + self.last_turn_args = turn_args + with generation_span(disabled=not self.tracing_enabled) as span: + output = self.get_next_output() + if isinstance(output, Exception): + span.set_error( + SpanError( + message="Error", + data={ + "name": output.__class__.__name__, + "message": str(output), + }, + ) + ) + raise output + + response = get_response_obj(output, usage=self.hardcoded_usage) + sequence_number = 0 + + yield ResponseCreatedEvent( + type="response.created", + response=response, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseInProgressEvent( + type="response.in_progress", + response=response, + sequence_number=sequence_number, + ) + sequence_number += 1 + + for output_index, output_item in enumerate(output): + yield ResponseOutputItemAddedEvent( + type="response.output_item.added", + item=output_item, + output_index=output_index, + sequence_number=sequence_number, + ) + sequence_number += 1 + + if isinstance(output_item, ResponseReasoningItem): + if output_item.summary: + for summary_index, summary in enumerate(output_item.summary): + yield ResponseReasoningSummaryPartAddedEvent( + type="response.reasoning_summary_part.added", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + part=AddedEventPart(text=summary.text, type=summary.type), + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryTextDeltaEvent( + type="response.reasoning_summary_text.delta", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + delta=summary.text, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryTextDoneEvent( + type="response.reasoning_summary_text.done", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + text=summary.text, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryPartDoneEvent( + type="response.reasoning_summary_part.done", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + part=DoneEventPart(text=summary.text, type=summary.type), + sequence_number=sequence_number, + ) + sequence_number += 1 + + elif isinstance(output_item, ResponseFunctionToolCall): + yield ResponseFunctionCallArgumentsDeltaEvent( + type="response.function_call_arguments.delta", + item_id=output_item.call_id, + output_index=output_index, + delta=output_item.arguments, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + item_id=output_item.call_id, + output_index=output_index, + arguments=output_item.arguments, + name=output_item.name, + sequence_number=sequence_number, + ) + sequence_number += 1 + + elif isinstance(output_item, ResponseOutputMessage): + for content_index, content_part in enumerate(output_item.content): + if isinstance(content_part, ResponseOutputText): + yield ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + part=content_part, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + delta=content_part.text, + logprobs=[], + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseTextDoneEvent( + type="response.output_text.done", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + text=content_part.text, + logprobs=[], + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + part=content_part, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseOutputItemDoneEvent( + type="response.output_item.done", + item=output_item, + output_index=output_index, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseCompletedEvent( + type="response.completed", + response=response, + sequence_number=sequence_number, + ) + + +def get_response_obj( + output: list[TResponseOutputItem], + response_id: str | None = None, + usage: Usage | None = None, +) -> Response: + return Response( + id=response_id or "resp-789", + created_at=123, + model="test_model", + object="response", + output=output, + tool_choice="none", + tools=[], + top_p=None, + parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), + ) diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py index 3d5ed5a3f..2e9334861 100644 --- a/tests/test_agent_prompt.py +++ b/tests/test_agent_prompt.py @@ -1,99 +1,100 @@ -import pytest - -from agents import Agent, Prompt, RunContextWrapper, Runner - -from .fake_model import FakeModel -from .test_responses import get_text_message - - -class PromptCaptureFakeModel(FakeModel): - """Subclass of FakeModel that records the prompt passed to the model.""" - - def __init__(self): - super().__init__() - self.last_prompt = None - - async def get_response( - self, - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - tracing, - *, - previous_response_id, - conversation_id, - prompt, - ): - # Record the prompt that the agent resolved and passed in. - self.last_prompt = prompt - return await super().get_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - tracing, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt, - ) - - -@pytest.mark.asyncio -async def test_static_prompt_is_resolved_correctly(): - static_prompt: Prompt = { - "id": "my_prompt", - "version": "1", - "variables": {"some_var": "some_value"}, - } - - agent = Agent(name="test", prompt=static_prompt) - context_wrapper = RunContextWrapper(context=None) - - resolved = await agent.get_prompt(context_wrapper) - - assert resolved == { - "id": "my_prompt", - "version": "1", - "variables": {"some_var": "some_value"}, - } - - -@pytest.mark.asyncio -async def test_dynamic_prompt_is_resolved_correctly(): - dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"} - - def dynamic_prompt_fn(_data): - return dynamic_prompt_value - - agent = Agent(name="test", prompt=dynamic_prompt_fn) - context_wrapper = RunContextWrapper(context=None) - - resolved = await agent.get_prompt(context_wrapper) - - assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None} - - -@pytest.mark.asyncio -async def test_prompt_is_passed_to_model(): - static_prompt: Prompt = {"id": "model_prompt"} - - model = PromptCaptureFakeModel() - agent = Agent(name="test", model=model, prompt=static_prompt) - - # Ensure the model returns a simple message so the run completes in one turn. - model.set_next_output([get_text_message("done")]) - - await Runner.run(agent, input="hello") - - # The model should have received the prompt resolved by the agent. - expected_prompt = { - "id": "model_prompt", - "version": None, - "variables": None, - } - assert model.last_prompt == expected_prompt +import pytest + +from agents import Agent, Prompt, RunContextWrapper, Runner + +from .fake_model import FakeModel +from .test_responses import get_text_message + + +class PromptCaptureFakeModel(FakeModel): + """Subclass of FakeModel that records the prompt passed to the model.""" + + def __init__(self): + super().__init__() + self.last_prompt = None + + async def get_response( + self, + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + *, + previous_response_id, + conversation_id, + prompt, + enable_structured_output_with_tools: bool = False, + ): + # Record the prompt that the agent resolved and passed in. + self.last_prompt = prompt + return await super().get_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + + +@pytest.mark.asyncio +async def test_static_prompt_is_resolved_correctly(): + static_prompt: Prompt = { + "id": "my_prompt", + "version": "1", + "variables": {"some_var": "some_value"}, + } + + agent = Agent(name="test", prompt=static_prompt) + context_wrapper = RunContextWrapper(context=None) + + resolved = await agent.get_prompt(context_wrapper) + + assert resolved == { + "id": "my_prompt", + "version": "1", + "variables": {"some_var": "some_value"}, + } + + +@pytest.mark.asyncio +async def test_dynamic_prompt_is_resolved_correctly(): + dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"} + + def dynamic_prompt_fn(_data): + return dynamic_prompt_value + + agent = Agent(name="test", prompt=dynamic_prompt_fn) + context_wrapper = RunContextWrapper(context=None) + + resolved = await agent.get_prompt(context_wrapper) + + assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None} + + +@pytest.mark.asyncio +async def test_prompt_is_passed_to_model(): + static_prompt: Prompt = {"id": "model_prompt"} + + model = PromptCaptureFakeModel() + agent = Agent(name="test", model=model, prompt=static_prompt) + + # Ensure the model returns a simple message so the run completes in one turn. + model.set_next_output([get_text_message("done")]) + + await Runner.run(agent, input="hello") + + # The model should have received the prompt resolved by the agent. + expected_prompt = { + "id": "model_prompt", + "version": None, + "variables": None, + } + assert model.last_prompt == expected_prompt diff --git a/tests/test_gemini_local.py b/tests/test_gemini_local.py new file mode 100644 index 000000000..208364425 --- /dev/null +++ b/tests/test_gemini_local.py @@ -0,0 +1,169 @@ +""" +Test script for Gemini with prompt injection feature. +Run this locally to test the implementation with your own API key. + +Usage: +1. Set your API key: export GOOGLE_API_KEY=your_key_here +2. Run: python test_gemini_local.py +""" + +import asyncio +import logging +import os +from typing import Any + +from pydantic import BaseModel + +from agents import Agent, function_tool +from agents.extensions.models.litellm_model import LitellmModel + +# Enable logging to see the final system prompt sent to Gemini +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +# Define your output schema +class WeatherReport(BaseModel): + """Weather report structure.""" + + city: str + temperature: float + conditions: str + humidity: int + + +# Define a simple tool +@function_tool +def get_weather(city: str) -> dict[str, Any]: + """Get the current weather for a city.""" + # Mock weather data + weather_data = { + "Tokyo": {"temperature": 22.5, "conditions": "sunny", "humidity": 65}, + "London": {"temperature": 15.0, "conditions": "rainy", "humidity": 80}, + "New York": {"temperature": 18.0, "conditions": "cloudy", "humidity": 70}, + } + + data = weather_data.get(city, {"temperature": 20.0, "conditions": "unknown", "humidity": 60}) + data["city"] = city + return data + + +async def main(): + """Main test function.""" + + # Check for API key + if not os.getenv("GOOGLE_API_KEY"): + print("ERROR: GOOGLE_API_KEY environment variable not set!") + print("\nTo set it:") + print(" Windows PowerShell: $env:GOOGLE_API_KEY='your_key_here'") + print(" Windows CMD: set GOOGLE_API_KEY=your_key_here") + print(" Linux/Mac: export GOOGLE_API_KEY=your_key_here") + return + + print("=" * 80) + print("Testing Gemini with Prompt Injection Feature") + print("=" * 80) + print("\n🔍 The final system prompt sent to Gemini will be shown below") + print("=" * 80) + + # Create agent with prompt injection enabled + agent = Agent( + name="weather_assistant", + instructions=( + "You are a helpful weather assistant. Use the get_weather tool to " + "fetch weather information, then provide a structured report." + ), + model=LitellmModel("gemini/gemini-2.5-flash"), + tools=[get_weather], + output_type=WeatherReport, + enable_structured_output_with_tools=True, # CRITICAL: Enable for Gemini! + ) + + print("\nAgent Configuration:") + print(" Model: gemini/gemini-2.5-flash") + print(f" Tools: {[tool.name for tool in agent.tools]}") + print(" Output Type: WeatherReport") + print(f" enable_structured_output_with_tools: {agent.enable_structured_output_with_tools}") + + print(f"\n{'=' * 80}") + print("Running agent with input: 'What's the weather in Tokyo?'") + print(f"{'=' * 80}\n") + + print("📤 Sending request to Gemini...") + print("⏳ Waiting for response...\n") + + try: + from agents import Runner + + result = await Runner.run( + starting_agent=agent, + input="What's the weather in Tokyo?", + ) + + print("\n✅ Agent execution completed!") + + print(f"\n{'=' * 80}") + print("🎉 SUCCESS! Response Received") + print(f"{'=' * 80}") + + print("\n📊 Result Analysis:") + print(f"{'=' * 80}") + print(f"Output Type: {type(result.final_output).__name__}") + print(f"Output Value: {result.final_output}") + print(f"{'=' * 80}") + + if isinstance(result.final_output, WeatherReport): + print("\n✅ STRUCTURED OUTPUT PARSING: SUCCESS!") + print(f"{'=' * 80}") + print("\n📋 Weather Report (Parsed from JSON):") + print(f"{'=' * 80}") + print(f" 🌍 City: {result.final_output.city}") + print(f" 🌡️ Temperature: {result.final_output.temperature}°C") + print(f" ☁️ Conditions: {result.final_output.conditions}") + print(f" 💧 Humidity: {result.final_output.humidity}%") + print(f"{'=' * 80}") + else: + print( + f"\n⚠️ WARNING: Output type is {type(result.final_output)}, expected WeatherReport" + ) + + print("\n📈 Token Usage:") + print(f"{'=' * 80}") + print(f" 📥 Input tokens: {result.context_wrapper.usage.input_tokens}") + print(f" 📤 Output tokens: {result.context_wrapper.usage.output_tokens}") + print(f" 📊 Total tokens: {result.context_wrapper.usage.total_tokens}") + print(f"{'=' * 80}") + + print("\n💡 What Happened:") + print(" 1. ✅ Prompt injection added JSON schema to system prompt") + print(" 2. ✅ Gemini called get_weather tool") + print(" 3. ✅ Gemini returned structured JSON matching WeatherReport schema") + print(" 4. ✅ SDK parsed JSON into WeatherReport Pydantic model") + print("\n🎯 Feature is working correctly!") + + except Exception as e: + print(f"\n{'=' * 80}") + print("❌ ERROR!") + print(f"{'=' * 80}") + print(f"\n💥 Error: {e}") + print("\n🔧 Troubleshooting Steps:") + print(f"{'=' * 80}") + print(" 1. ✓ Check your API key is valid") + print(" 2. ✓ Ensure litellm is installed: pip install 'openai-agents[litellm]'") + print(" 3. ✓ Check internet connection") + print(" 4. ✓ Check DEBUG logs above for prompt details") + print(f"{'=' * 80}") + + import traceback + + print("\n📋 Full traceback:") + print(f"{'=' * 80}") + traceback.print_exc() + print(f"{'=' * 80}") + + +if __name__ == "__main__": + print("\n" + "=" * 80) + print("Gemini + Prompt Injection Test") + print("=" * 80 + "\n") + + asyncio.run(main()) diff --git a/tests/test_streaming_tool_call_arguments.py b/tests/test_streaming_tool_call_arguments.py index ce476e59b..041a24713 100644 --- a/tests/test_streaming_tool_call_arguments.py +++ b/tests/test_streaming_tool_call_arguments.py @@ -1,373 +1,375 @@ -""" -Tests to ensure that tool call arguments are properly populated in streaming events. - -This test specifically guards against the regression where tool_called events -were emitted with empty arguments during streaming (Issue #1629). -""" - -import json -from collections.abc import AsyncIterator -from typing import Any, Optional, Union, cast - -import pytest -from openai.types.responses import ( - ResponseCompletedEvent, - ResponseFunctionToolCall, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, -) - -from agents import Agent, Runner, function_tool -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import TResponseInputItem, TResponseOutputItem, TResponseStreamEvent -from agents.model_settings import ModelSettings -from agents.models.interface import Model, ModelTracing -from agents.stream_events import RunItemStreamEvent -from agents.tool import Tool -from agents.tracing import generation_span - -from .fake_model import get_response_obj -from .test_responses import get_function_tool_call - - -class StreamingFakeModel(Model): - """A fake model that actually emits streaming events to test our streaming fix.""" - - def __init__(self): - self.turn_outputs: list[list[TResponseOutputItem]] = [] - self.last_turn_args: dict[str, Any] = {} - - def set_next_output(self, output: list[TResponseOutputItem]): - self.turn_outputs.append(output) - - def get_next_output(self) -> list[TResponseOutputItem]: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: Optional[str], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Optional[AgentOutputSchemaBase], - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: Optional[str], - conversation_id: Optional[str], - prompt: Optional[Any], - ): - raise NotImplementedError("Use stream_response instead") - - async def stream_response( - self, - system_instructions: Optional[str], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Optional[AgentOutputSchemaBase], - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: Optional[str] = None, - conversation_id: Optional[str] = None, - prompt: Optional[Any] = None, - ) -> AsyncIterator[TResponseStreamEvent]: - """Stream events that simulate real OpenAI streaming behavior for tool calls.""" - self.last_turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - with generation_span(disabled=True) as _: - output = self.get_next_output() - - sequence_number = 0 - - # Emit each output item with proper streaming events - for item in output: - if isinstance(item, ResponseFunctionToolCall): - # First: emit ResponseOutputItemAddedEvent with EMPTY arguments - # (this simulates the real streaming behavior that was causing the bug) - empty_args_item = ResponseFunctionToolCall( - id=item.id, - call_id=item.call_id, - type=item.type, - name=item.name, - arguments="", # EMPTY - this is the bug condition! - ) - - yield ResponseOutputItemAddedEvent( - item=empty_args_item, - output_index=0, - type="response.output_item.added", - sequence_number=sequence_number, - ) - sequence_number += 1 - - # Then: emit ResponseOutputItemDoneEvent with COMPLETE arguments - yield ResponseOutputItemDoneEvent( - item=item, # This has the complete arguments - output_index=0, - type="response.output_item.done", - sequence_number=sequence_number, - ) - sequence_number += 1 - - # Finally: emit completion - yield ResponseCompletedEvent( - type="response.completed", - response=get_response_obj(output), - sequence_number=sequence_number, - ) - - -@function_tool -def calculate_sum(a: int, b: int) -> str: - """Add two numbers together.""" - return str(a + b) - - -@function_tool -def format_message(name: str, message: str, urgent: bool = False) -> str: - """Format a message with name and urgency.""" - prefix = "URGENT: " if urgent else "" - return f"{prefix}Hello {name}, {message}" - - -@pytest.mark.asyncio -async def test_streaming_tool_call_arguments_not_empty(): - """Test that tool_called events contain non-empty arguments during streaming.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[calculate_sum], - ) - - # Set up a tool call with arguments - expected_arguments = '{"a": 5, "b": 3}' - model.set_next_output( - [ - get_function_tool_call("calculate_sum", expected_arguments, "call_123"), - ] - ) - - result = Runner.run_streamed(agent, input="Add 5 and 3") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - # Verify we got exactly one tool_called event - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - - # Verify the event has the expected structure - assert hasattr(tool_event.item, "raw_item"), "tool_called event should have raw_item" - assert hasattr(tool_event.item.raw_item, "arguments"), "raw_item should have arguments field" - - # The critical test: arguments should NOT be empty - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - assert actual_arguments != "", ( - f"Tool call arguments should not be empty, got: '{actual_arguments}'" - ) - assert actual_arguments is not None, "Tool call arguments should not be None" - - # Verify arguments contain the expected data - assert actual_arguments == expected_arguments, ( - f"Expected arguments '{expected_arguments}', got '{actual_arguments}'" - ) - - # Verify arguments are valid JSON that can be parsed - try: - parsed_args = json.loads(actual_arguments) - assert parsed_args == {"a": 5, "b": 3}, ( - f"Parsed arguments should match expected values, got {parsed_args}" - ) - except json.JSONDecodeError as e: - pytest.fail( - f"Tool call arguments should be valid JSON, but got: '{actual_arguments}' with error: {e}" # noqa: E501 - ) - - -@pytest.mark.asyncio -async def test_streaming_tool_call_arguments_complex(): - """Test streaming tool calls with complex arguments including strings and booleans.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[format_message], - ) - - # Set up a tool call with complex arguments - expected_arguments = ( - '{"name": "Alice", "message": "Your meeting is starting soon", "urgent": true}' - ) - model.set_next_output( - [ - get_function_tool_call("format_message", expected_arguments, "call_456"), - ] - ) - - result = Runner.run_streamed(agent, input="Format a message for Alice") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - - # Critical checks for the regression - assert actual_arguments != "", "Tool call arguments should not be empty" - assert actual_arguments is not None, "Tool call arguments should not be None" - assert actual_arguments == expected_arguments, ( - f"Expected '{expected_arguments}', got '{actual_arguments}'" - ) - - # Verify the complex arguments parse correctly - parsed_args = json.loads(actual_arguments) - expected_parsed = {"name": "Alice", "message": "Your meeting is starting soon", "urgent": True} - assert parsed_args == expected_parsed, f"Parsed arguments should match, got {parsed_args}" - - -@pytest.mark.asyncio -async def test_streaming_multiple_tool_calls_arguments(): - """Test that multiple tool calls in streaming all have proper arguments.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[calculate_sum, format_message], - ) - - # Set up multiple tool calls - model.set_next_output( - [ - get_function_tool_call("calculate_sum", '{"a": 10, "b": 20}', "call_1"), - get_function_tool_call( - "format_message", '{"name": "Bob", "message": "Test"}', "call_2" - ), - ] - ) - - result = Runner.run_streamed(agent, input="Do some calculations") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - # Should have exactly 2 tool_called events - assert len(tool_called_events) == 2, ( - f"Expected 2 tool_called events, got {len(tool_called_events)}" - ) - - # Check first tool call - event1 = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item1 = cast(ResponseFunctionToolCall, event1.item.raw_item) - args1 = raw_item1.arguments - assert args1 != "", "First tool call arguments should not be empty" - expected_args1 = '{"a": 10, "b": 20}' - assert args1 == expected_args1, ( - f"First tool call args: expected '{expected_args1}', got '{args1}'" - ) - - # Check second tool call - event2 = tool_called_events[1] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item2 = cast(ResponseFunctionToolCall, event2.item.raw_item) - args2 = raw_item2.arguments - assert args2 != "", "Second tool call arguments should not be empty" - expected_args2 = '{"name": "Bob", "message": "Test"}' - assert args2 == expected_args2, ( - f"Second tool call args: expected '{expected_args2}', got '{args2}'" - ) - - -@pytest.mark.asyncio -async def test_streaming_tool_call_with_empty_arguments(): - """Test that tool calls with legitimately empty arguments still work correctly.""" - model = StreamingFakeModel() - - @function_tool - def get_current_time() -> str: - """Get the current time (no arguments needed).""" - return "2024-01-15 10:30:00" - - agent = Agent( - name="TestAgent", - model=model, - tools=[get_current_time], - ) - - # Tool call with empty arguments (legitimate case) - model.set_next_output( - [ - get_function_tool_call("get_current_time", "{}", "call_time"), - ] - ) - - result = Runner.run_streamed(agent, input="What time is it?") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - - # Even "empty" arguments should be "{}", not literally empty string - assert actual_arguments is not None, "Arguments should not be None" - assert actual_arguments == "{}", f"Expected empty JSON object '{{}}', got '{actual_arguments}'" - - # Should parse as valid empty JSON - parsed_args = json.loads(actual_arguments) - assert parsed_args == {}, f"Should parse to empty dict, got {parsed_args}" +""" +Tests to ensure that tool call arguments are properly populated in streaming events. + +This test specifically guards against the regression where tool_called events +were emitted with empty arguments during streaming (Issue #1629). +""" + +import json +from collections.abc import AsyncIterator +from typing import Any, Optional, Union, cast + +import pytest +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, +) + +from agents import Agent, Runner, function_tool +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import TResponseInputItem, TResponseOutputItem, TResponseStreamEvent +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.stream_events import RunItemStreamEvent +from agents.tool import Tool +from agents.tracing import generation_span + +from .fake_model import get_response_obj +from .test_responses import get_function_tool_call + + +class StreamingFakeModel(Model): + """A fake model that actually emits streaming events to test our streaming fix.""" + + def __init__(self): + self.turn_outputs: list[list[TResponseOutputItem]] = [] + self.last_turn_args: dict[str, Any] = {} + + def set_next_output(self, output: list[TResponseOutputItem]): + self.turn_outputs.append(output) + + def get_next_output(self) -> list[TResponseOutputItem]: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str], + conversation_id: Optional[str], + prompt: Optional[Any], + enable_structured_output_with_tools: bool = False, + ): + raise NotImplementedError("Use stream_response instead") + + async def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt: Optional[Any] = None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[TResponseStreamEvent]: + """Stream events that simulate real OpenAI streaming behavior for tool calls.""" + self.last_turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + with generation_span(disabled=True) as _: + output = self.get_next_output() + + sequence_number = 0 + + # Emit each output item with proper streaming events + for item in output: + if isinstance(item, ResponseFunctionToolCall): + # First: emit ResponseOutputItemAddedEvent with EMPTY arguments + # (this simulates the real streaming behavior that was causing the bug) + empty_args_item = ResponseFunctionToolCall( + id=item.id, + call_id=item.call_id, + type=item.type, + name=item.name, + arguments="", # EMPTY - this is the bug condition! + ) + + yield ResponseOutputItemAddedEvent( + item=empty_args_item, + output_index=0, + type="response.output_item.added", + sequence_number=sequence_number, + ) + sequence_number += 1 + + # Then: emit ResponseOutputItemDoneEvent with COMPLETE arguments + yield ResponseOutputItemDoneEvent( + item=item, # This has the complete arguments + output_index=0, + type="response.output_item.done", + sequence_number=sequence_number, + ) + sequence_number += 1 + + # Finally: emit completion + yield ResponseCompletedEvent( + type="response.completed", + response=get_response_obj(output), + sequence_number=sequence_number, + ) + + +@function_tool +def calculate_sum(a: int, b: int) -> str: + """Add two numbers together.""" + return str(a + b) + + +@function_tool +def format_message(name: str, message: str, urgent: bool = False) -> str: + """Format a message with name and urgency.""" + prefix = "URGENT: " if urgent else "" + return f"{prefix}Hello {name}, {message}" + + +@pytest.mark.asyncio +async def test_streaming_tool_call_arguments_not_empty(): + """Test that tool_called events contain non-empty arguments during streaming.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[calculate_sum], + ) + + # Set up a tool call with arguments + expected_arguments = '{"a": 5, "b": 3}' + model.set_next_output( + [ + get_function_tool_call("calculate_sum", expected_arguments, "call_123"), + ] + ) + + result = Runner.run_streamed(agent, input="Add 5 and 3") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + # Verify we got exactly one tool_called event + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + + # Verify the event has the expected structure + assert hasattr(tool_event.item, "raw_item"), "tool_called event should have raw_item" + assert hasattr(tool_event.item.raw_item, "arguments"), "raw_item should have arguments field" + + # The critical test: arguments should NOT be empty + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + assert actual_arguments != "", ( + f"Tool call arguments should not be empty, got: '{actual_arguments}'" + ) + assert actual_arguments is not None, "Tool call arguments should not be None" + + # Verify arguments contain the expected data + assert actual_arguments == expected_arguments, ( + f"Expected arguments '{expected_arguments}', got '{actual_arguments}'" + ) + + # Verify arguments are valid JSON that can be parsed + try: + parsed_args = json.loads(actual_arguments) + assert parsed_args == {"a": 5, "b": 3}, ( + f"Parsed arguments should match expected values, got {parsed_args}" + ) + except json.JSONDecodeError as e: + pytest.fail( + f"Tool call arguments should be valid JSON, but got: '{actual_arguments}' with error: {e}" # noqa: E501 + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_call_arguments_complex(): + """Test streaming tool calls with complex arguments including strings and booleans.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[format_message], + ) + + # Set up a tool call with complex arguments + expected_arguments = ( + '{"name": "Alice", "message": "Your meeting is starting soon", "urgent": true}' + ) + model.set_next_output( + [ + get_function_tool_call("format_message", expected_arguments, "call_456"), + ] + ) + + result = Runner.run_streamed(agent, input="Format a message for Alice") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + + # Critical checks for the regression + assert actual_arguments != "", "Tool call arguments should not be empty" + assert actual_arguments is not None, "Tool call arguments should not be None" + assert actual_arguments == expected_arguments, ( + f"Expected '{expected_arguments}', got '{actual_arguments}'" + ) + + # Verify the complex arguments parse correctly + parsed_args = json.loads(actual_arguments) + expected_parsed = {"name": "Alice", "message": "Your meeting is starting soon", "urgent": True} + assert parsed_args == expected_parsed, f"Parsed arguments should match, got {parsed_args}" + + +@pytest.mark.asyncio +async def test_streaming_multiple_tool_calls_arguments(): + """Test that multiple tool calls in streaming all have proper arguments.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[calculate_sum, format_message], + ) + + # Set up multiple tool calls + model.set_next_output( + [ + get_function_tool_call("calculate_sum", '{"a": 10, "b": 20}', "call_1"), + get_function_tool_call( + "format_message", '{"name": "Bob", "message": "Test"}', "call_2" + ), + ] + ) + + result = Runner.run_streamed(agent, input="Do some calculations") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + # Should have exactly 2 tool_called events + assert len(tool_called_events) == 2, ( + f"Expected 2 tool_called events, got {len(tool_called_events)}" + ) + + # Check first tool call + event1 = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item1 = cast(ResponseFunctionToolCall, event1.item.raw_item) + args1 = raw_item1.arguments + assert args1 != "", "First tool call arguments should not be empty" + expected_args1 = '{"a": 10, "b": 20}' + assert args1 == expected_args1, ( + f"First tool call args: expected '{expected_args1}', got '{args1}'" + ) + + # Check second tool call + event2 = tool_called_events[1] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item2 = cast(ResponseFunctionToolCall, event2.item.raw_item) + args2 = raw_item2.arguments + assert args2 != "", "Second tool call arguments should not be empty" + expected_args2 = '{"name": "Bob", "message": "Test"}' + assert args2 == expected_args2, ( + f"Second tool call args: expected '{expected_args2}', got '{args2}'" + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_call_with_empty_arguments(): + """Test that tool calls with legitimately empty arguments still work correctly.""" + model = StreamingFakeModel() + + @function_tool + def get_current_time() -> str: + """Get the current time (no arguments needed).""" + return "2024-01-15 10:30:00" + + agent = Agent( + name="TestAgent", + model=model, + tools=[get_current_time], + ) + + # Tool call with empty arguments (legitimate case) + model.set_next_output( + [ + get_function_tool_call("get_current_time", "{}", "call_time"), + ] + ) + + result = Runner.run_streamed(agent, input="What time is it?") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + + # Even "empty" arguments should be "{}", not literally empty string + assert actual_arguments is not None, "Arguments should not be None" + assert actual_arguments == "{}", f"Expected empty JSON object '{{}}', got '{actual_arguments}'" + + # Should parse as valid empty JSON + parsed_args = json.loads(actual_arguments) + assert parsed_args == {}, f"Should parse to empty dict, got {parsed_args}" diff --git a/tests/utils/test_prompts.py b/tests/utils/test_prompts.py new file mode 100644 index 000000000..503dd7dfe --- /dev/null +++ b/tests/utils/test_prompts.py @@ -0,0 +1,107 @@ +from pydantic import BaseModel + +from agents.agent_output import AgentOutputSchema +from agents.util._prompts import get_json_output_prompt, should_inject_json_prompt + + +class SimpleModel(BaseModel): + name: str + age: int + + +class NestedModel(BaseModel): + user: SimpleModel + active: bool + + +def test_get_json_output_prompt_returns_empty_for_plain_text(): + schema = AgentOutputSchema(str) + result = get_json_output_prompt(schema) + assert result == "" + + +def test_get_json_output_prompt_with_simple_schema(): + schema = AgentOutputSchema(SimpleModel) + result = get_json_output_prompt(schema) + assert "name" in result + assert "age" in result + assert "JSON" in result + + +def test_get_json_output_prompt_with_nested_schema(): + schema = AgentOutputSchema(NestedModel) + result = get_json_output_prompt(schema) + assert "user" in result + assert "active" in result + assert "JSON" in result + + +def test_get_json_output_prompt_handles_schema_error(): + schema = AgentOutputSchema(SimpleModel) + result = get_json_output_prompt(schema) + assert isinstance(result, str) + assert len(result) > 0 + + +def test_should_inject_json_prompt_default_false(): + schema = AgentOutputSchema(SimpleModel) + tools = [{"type": "function", "name": "test"}] + result = should_inject_json_prompt(schema, tools) + assert result is False + + +def test_should_inject_json_prompt_explicit_opt_in(): + schema = AgentOutputSchema(SimpleModel) + tools = [{"type": "function", "name": "test"}] + result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=True) + assert result is True + + +def test_should_inject_json_prompt_no_schema(): + result = should_inject_json_prompt( + None, [{"type": "function"}], enable_structured_output_with_tools=True + ) + assert result is False + + +def test_should_inject_json_prompt_plain_text_schema(): + schema = AgentOutputSchema(str) + tools = [{"type": "function"}] + result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=True) + assert result is False + + +def test_should_inject_json_prompt_no_tools(): + schema = AgentOutputSchema(SimpleModel) + result = should_inject_json_prompt(schema, [], enable_structured_output_with_tools=True) + assert result is False + + +def test_should_inject_json_prompt_empty_tools(): + schema = AgentOutputSchema(SimpleModel) + result = should_inject_json_prompt(schema, [], enable_structured_output_with_tools=True) + assert result is False + + +def test_should_inject_json_prompt_all_conditions_met(): + schema = AgentOutputSchema(SimpleModel) + tools = [{"type": "function", "name": "test"}] + result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=True) + assert result is True + + +def test_should_inject_json_prompt_without_opt_in(): + schema = AgentOutputSchema(SimpleModel) + tools = [{"type": "function", "name": "test"}] + result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=False) + assert result is False + + +def test_should_inject_json_prompt_multiple_tools(): + schema = AgentOutputSchema(SimpleModel) + tools = [ + {"type": "function", "name": "test1"}, + {"type": "function", "name": "test2"}, + ] + result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=True) + assert result is True diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 402c52128..ca47faf6e 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -1,219 +1,221 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncIterator -from typing import Any - -import pytest -from inline_snapshot import snapshot -from openai.types.responses import ResponseCompletedEvent -from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent - -from agents import Agent, Model, ModelSettings, ModelTracing, Tool -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import ( - ModelResponse, - TResponseInputItem, - TResponseOutputItem, - TResponseStreamEvent, -) - -from ..fake_model import get_response_obj -from ..test_responses import get_function_tool, get_function_tool_call, get_text_message - -try: - from agents.voice import SingleAgentVoiceWorkflow - -except ImportError: - pass - - -class FakeStreamingModel(Model): - def __init__(self): - self.turn_outputs: list[list[TResponseOutputItem]] = [] - - def set_next_output(self, output: list[TResponseOutputItem]): - self.turn_outputs.append(output) - - def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem]]): - self.turn_outputs.extend(outputs) - - def get_next_output(self) -> list[TResponseOutputItem]: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> ModelResponse: - raise NotImplementedError("Not implemented") - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> AsyncIterator[TResponseStreamEvent]: - output = self.get_next_output() - for item in output: - if ( - item.type == "message" - and len(item.content) == 1 - and item.content[0].type == "output_text" - ): - yield ResponseTextDeltaEvent( - content_index=0, - delta=item.content[0].text, - type="response.output_text.delta", - output_index=0, - item_id=item.id, - sequence_number=0, - logprobs=[], - ) - - yield ResponseCompletedEvent( - type="response.completed", - response=get_response_obj(output), - sequence_number=1, - ) - - -@pytest.mark.asyncio -async def test_single_agent_workflow(monkeypatch) -> None: - model = FakeStreamingModel() - model.add_multiple_turn_outputs( - [ - # First turn: a message and a tool call - [ - get_function_tool_call("some_function", json.dumps({"a": "b"})), - get_text_message("a_message"), - ], - # Second turn: text message - [get_text_message("done")], - ] - ) - - agent = Agent( - "initial_agent", - model=model, - tools=[get_function_tool("some_function", "tool_result")], - ) - - workflow = SingleAgentVoiceWorkflow(agent) - output = [] - async for chunk in workflow.run("transcription_1"): - output.append(chunk) - - # Validate that the text yielded matches our fake events - assert output == ["a_message", "done"] - # Validate that internal state was updated - assert workflow._input_history == snapshot( - [ - {"content": "transcription_1", "role": "user"}, - { - "arguments": '{"a": "b"}', - "call_id": "2", - "name": "some_function", - "type": "function_call", - "id": "1", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - { - "call_id": "2", - "output": "tool_result", - "type": "function_call_output", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - ] - ) - assert workflow._current_agent == agent - - model.set_next_output([get_text_message("done_2")]) - - # Run it again with a new transcription to make sure the input history is updated - output = [] - async for chunk in workflow.run("transcription_2"): - output.append(chunk) - - assert workflow._input_history == snapshot( - [ - {"role": "user", "content": "transcription_1"}, - { - "arguments": '{"a": "b"}', - "call_id": "2", - "name": "some_function", - "type": "function_call", - "id": "1", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - { - "call_id": "2", - "output": "tool_result", - "type": "function_call_output", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - {"role": "user", "content": "transcription_2"}, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done_2", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - ] - ) - assert workflow._current_agent == agent +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Any + +import pytest +from inline_snapshot import snapshot +from openai.types.responses import ResponseCompletedEvent +from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent + +from agents import Agent, Model, ModelSettings, ModelTracing, Tool +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import ( + ModelResponse, + TResponseInputItem, + TResponseOutputItem, + TResponseStreamEvent, +) + +from ..fake_model import get_response_obj +from ..test_responses import get_function_tool, get_function_tool_call, get_text_message + +try: + from agents.voice import SingleAgentVoiceWorkflow + +except ImportError: + pass + + +class FakeStreamingModel(Model): + def __init__(self): + self.turn_outputs: list[list[TResponseOutputItem]] = [] + + def set_next_output(self, output: list[TResponseOutputItem]): + self.turn_outputs.append(output) + + def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem]]): + self.turn_outputs.extend(outputs) + + def get_next_output(self) -> list[TResponseOutputItem]: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + enable_structured_output_with_tools: bool = False, + ) -> ModelResponse: + raise NotImplementedError("Not implemented") + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + enable_structured_output_with_tools: bool = False, + ) -> AsyncIterator[TResponseStreamEvent]: + output = self.get_next_output() + for item in output: + if ( + item.type == "message" + and len(item.content) == 1 + and item.content[0].type == "output_text" + ): + yield ResponseTextDeltaEvent( + content_index=0, + delta=item.content[0].text, + type="response.output_text.delta", + output_index=0, + item_id=item.id, + sequence_number=0, + logprobs=[], + ) + + yield ResponseCompletedEvent( + type="response.completed", + response=get_response_obj(output), + sequence_number=1, + ) + + +@pytest.mark.asyncio +async def test_single_agent_workflow(monkeypatch) -> None: + model = FakeStreamingModel() + model.add_multiple_turn_outputs( + [ + # First turn: a message and a tool call + [ + get_function_tool_call("some_function", json.dumps({"a": "b"})), + get_text_message("a_message"), + ], + # Second turn: text message + [get_text_message("done")], + ] + ) + + agent = Agent( + "initial_agent", + model=model, + tools=[get_function_tool("some_function", "tool_result")], + ) + + workflow = SingleAgentVoiceWorkflow(agent) + output = [] + async for chunk in workflow.run("transcription_1"): + output.append(chunk) + + # Validate that the text yielded matches our fake events + assert output == ["a_message", "done"] + # Validate that internal state was updated + assert workflow._input_history == snapshot( + [ + {"content": "transcription_1", "role": "user"}, + { + "arguments": '{"a": "b"}', + "call_id": "2", + "name": "some_function", + "type": "function_call", + "id": "1", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + { + "call_id": "2", + "output": "tool_result", + "type": "function_call_output", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + ] + ) + assert workflow._current_agent == agent + + model.set_next_output([get_text_message("done_2")]) + + # Run it again with a new transcription to make sure the input history is updated + output = [] + async for chunk in workflow.run("transcription_2"): + output.append(chunk) + + assert workflow._input_history == snapshot( + [ + {"role": "user", "content": "transcription_1"}, + { + "arguments": '{"a": "b"}', + "call_id": "2", + "name": "some_function", + "type": "function_call", + "id": "1", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + { + "call_id": "2", + "output": "tool_result", + "type": "function_call_output", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + {"role": "user", "content": "transcription_2"}, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done_2", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + ] + ) + assert workflow._current_agent == agent From 6c637fd04810b0e138403a3ccf1d0e282c4bacdb Mon Sep 17 00:00:00 2001 From: shayan-devv Date: Mon, 10 Nov 2025 21:47:11 +0500 Subject: [PATCH 2/4] fix: only pass enable_structured_output_with_tools when enabled Maintain backward compatibility with third-party Model implementations by only passing the enable_structured_output_with_tools parameter when it's explicitly enabled (True). This prevents TypeErrors in custom Model classes that don't support the new parameter yet. - Built-in models have default False, so they work either way - Third-party models without the parameter won't crash - Feature still works when explicitly enabled Fixes backward compatibility issue raised in code review. --- src/agents/run.py | 3812 +++++++++++++++++++++++---------------------- 1 file changed, 1915 insertions(+), 1897 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 69cc88815..6867daf33 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,1897 +1,1915 @@ -from __future__ import annotations - -import asyncio -import contextlib -import inspect -import os -import warnings -from dataclasses import dataclass, field -from typing import Any, Callable, Generic, cast, get_args - -from openai.types.responses import ( - ResponseCompletedEvent, - ResponseOutputItemDoneEvent, -) -from openai.types.responses.response_prompt_param import ( - ResponsePromptParam, -) -from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from typing_extensions import NotRequired, TypedDict, Unpack - -from ._run_impl import ( - AgentToolUseTracker, - NextStepFinalOutput, - NextStepHandoff, - NextStepRunAgain, - QueueCompleteSentinel, - RunImpl, - SingleStepResult, - TraceCtxManager, - get_model_tracing_impl, -) -from .agent import Agent -from .agent_output import AgentOutputSchema, AgentOutputSchemaBase -from .exceptions import ( - AgentsException, - InputGuardrailTripwireTriggered, - MaxTurnsExceeded, - ModelBehaviorError, - OutputGuardrailTripwireTriggered, - RunErrorDetails, - UserError, -) -from .guardrail import ( - InputGuardrail, - InputGuardrailResult, - OutputGuardrail, - OutputGuardrailResult, -) -from .handoffs import Handoff, HandoffInputFilter, handoff -from .items import ( - HandoffCallItem, - ItemHelpers, - ModelResponse, - ReasoningItem, - RunItem, - ToolCallItem, - ToolCallItemTypes, - TResponseInputItem, -) -from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase -from .logger import logger -from .memory import Session, SessionInputCallback -from .model_settings import ModelSettings -from .models.interface import Model, ModelProvider -from .models.multi_provider import MultiProvider -from .result import RunResult, RunResultStreaming -from .run_context import RunContextWrapper, TContext -from .stream_events import ( - AgentUpdatedStreamEvent, - RawResponsesStreamEvent, - RunItemStreamEvent, - StreamEvent, -) -from .tool import Tool -from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult -from .tracing import Span, SpanError, agent_span, get_current_trace, trace -from .tracing.span_data import AgentSpanData -from .usage import Usage -from .util import _coro, _error_tracing -from .util._types import MaybeAwaitable - -DEFAULT_MAX_TURNS = 10 - -DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore -# the value is set at the end of the module - - -def set_default_agent_runner(runner: AgentRunner | None) -> None: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly. - """ - global DEFAULT_AGENT_RUNNER - DEFAULT_AGENT_RUNNER = runner or AgentRunner() - - -def get_default_agent_runner() -> AgentRunner: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly. - """ - global DEFAULT_AGENT_RUNNER - return DEFAULT_AGENT_RUNNER - - -def _default_trace_include_sensitive_data() -> bool: - """Returns the default value for trace_include_sensitive_data based on environment variable.""" - val = os.getenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") - return val.strip().lower() in ("1", "true", "yes", "on") - - -@dataclass -class ModelInputData: - """Container for the data that will be sent to the model.""" - - input: list[TResponseInputItem] - instructions: str | None - - -@dataclass -class CallModelData(Generic[TContext]): - """Data passed to `RunConfig.call_model_input_filter` prior to model call.""" - - model_data: ModelInputData - agent: Agent[TContext] - context: TContext | None - - -@dataclass -class _ServerConversationTracker: - """Tracks server-side conversation state for either conversation_id or - previous_response_id modes.""" - - conversation_id: str | None = None - previous_response_id: str | None = None - sent_items: set[int] = field(default_factory=set) - server_items: set[int] = field(default_factory=set) - - def track_server_items(self, model_response: ModelResponse) -> None: - for output_item in model_response.output: - self.server_items.add(id(output_item)) - - # Update previous_response_id only when using previous_response_id - if ( - self.conversation_id is None - and self.previous_response_id is not None - and model_response.response_id is not None - ): - self.previous_response_id = model_response.response_id - - def prepare_input( - self, - original_input: str | list[TResponseInputItem], - generated_items: list[RunItem], - ) -> list[TResponseInputItem]: - input_items: list[TResponseInputItem] = [] - - # On first call (when there are no generated items yet), include the original input - if not generated_items: - input_items.extend(ItemHelpers.input_to_new_input_list(original_input)) - - # Process generated_items, skip items already sent or from server - for item in generated_items: - raw_item_id = id(item.raw_item) - - if raw_item_id in self.sent_items or raw_item_id in self.server_items: - continue - input_items.append(item.to_input_item()) - self.sent_items.add(raw_item_id) - - return input_items - - -# Type alias for the optional input filter callback -CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] - - -@dataclass -class RunConfig: - """Configures settings for the entire agent run.""" - - model: str | Model | None = None - """The model to use for the entire agent run. If set, will override the model set on every - agent. The model_provider passed in below must be able to resolve this model name. - """ - - model_provider: ModelProvider = field(default_factory=MultiProvider) - """The model provider to use when looking up string model names. Defaults to OpenAI.""" - - model_settings: ModelSettings | None = None - """Configure global model settings. Any non-null values will override the agent-specific model - settings. - """ - - handoff_input_filter: HandoffInputFilter | None = None - """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that - will take precedence. The input filter allows you to edit the inputs that are sent to the new - agent. See the documentation in `Handoff.input_filter` for more details. - """ - - input_guardrails: list[InputGuardrail[Any]] | None = None - """A list of input guardrails to run on the initial run input.""" - - output_guardrails: list[OutputGuardrail[Any]] | None = None - """A list of output guardrails to run on the final output of the run.""" - - tracing_disabled: bool = False - """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. - """ - - trace_include_sensitive_data: bool = field( - default_factory=_default_trace_include_sensitive_data - ) - """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or - LLM generations) in traces. If False, we'll still create spans for these events, but the - sensitive data will not be included. - """ - - workflow_name: str = "Agent workflow" - """The name of the run, used for tracing. Should be a logical name for the run, like - "Code generation workflow" or "Customer support agent". - """ - - trace_id: str | None = None - """A custom trace ID to use for tracing. If not provided, we will generate a new trace ID.""" - - group_id: str | None = None - """ - A grouping identifier to use for tracing, to link multiple traces from the same conversation - or process. For example, you might use a chat thread ID. - """ - - trace_metadata: dict[str, Any] | None = None - """ - An optional dictionary of additional metadata to include with the trace. - """ - - session_input_callback: SessionInputCallback | None = None - """Defines how to handle session history when new input is provided. - - `None` (default): The new input is appended to the session history. - - `SessionInputCallback`: A custom function that receives the history and new input, and - returns the desired combined list of items. - """ - - call_model_input_filter: CallModelInputFilter | None = None - """ - Optional callback that is invoked immediately before calling the model. It receives the current - agent, context and the model input (instructions and input items), and must return a possibly - modified `ModelInputData` to use for the model call. - - This allows you to edit the input sent to the model e.g. to stay within a token limit. - For example, you can use this to add a system prompt to the input. - """ - - -class RunOptions(TypedDict, Generic[TContext]): - """Arguments for ``AgentRunner`` methods.""" - - context: NotRequired[TContext | None] - """The context for the run.""" - - max_turns: NotRequired[int] - """The maximum number of turns to run for.""" - - hooks: NotRequired[RunHooks[TContext] | None] - """Lifecycle hooks for the run.""" - - run_config: NotRequired[RunConfig | None] - """Run configuration.""" - - previous_response_id: NotRequired[str | None] - """The ID of the previous response, if any.""" - - conversation_id: NotRequired[str | None] - """The ID of the stored conversation, if any.""" - - session: NotRequired[Session | None] - """The session for the run.""" - - -class Runner: - @classmethod - async def run( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResult: - """ - Run a workflow starting at the given agent. - - The agent will run in a loop until a final output is generated. The loop runs like so: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response. If using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The conversation ID - (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). - If provided, the conversation will be used to read and write items. - Every agent will have access to the conversation history so far, - and its output items will be written to the conversation. - We recommend only using this if you are exclusively using OpenAI models; - other model providers don't write to the Conversation object, - so you'll end up having partial conversations stored. - session: A session for automatic conversation history management. - - Returns: - A run result containing all the inputs, guardrail results and the output of - the last agent. Agents may perform handoffs, so we don't know the specific - type of the output. - """ - - runner = DEFAULT_AGENT_RUNNER - return await runner.run( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - @classmethod - def run_sync( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResult: - """ - Run a workflow synchronously, starting at the given agent. - - Note: - This just wraps the `run` method, so it will not work if there's already an - event loop (e.g. inside an async function, or in a Jupyter notebook or async - context like FastAPI). For those cases, use the `run` method instead. - - The agent will run in a loop until a final output is generated. The loop runs: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The ID of the stored conversation, if any. - session: A session for automatic conversation history management. - - Returns: - A run result containing all the inputs, guardrail results and the output of - the last agent. Agents may perform handoffs, so we don't know the specific - type of the output. - """ - - runner = DEFAULT_AGENT_RUNNER - return runner.run_sync( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - @classmethod - def run_streamed( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResultStreaming: - """ - Run a workflow starting at the given agent in streaming mode. - - The returned result object contains a method you can use to stream semantic - events as they are generated. - - The agent will run in a loop until a final output is generated. The loop runs like so: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The ID of the stored conversation, if any. - session: A session for automatic conversation history management. - - Returns: - A result object that contains data about the run, as well as a method to - stream events. - """ - - runner = DEFAULT_AGENT_RUNNER - return runner.run_streamed( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - -class AgentRunner: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly or subclassed. - """ - - async def run( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResult: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - if run_config is None: - run_config = RunConfig() - - if conversation_id is not None or previous_response_id is not None: - server_conversation_tracker = _ServerConversationTracker( - conversation_id=conversation_id, previous_response_id=previous_response_id - ) - else: - server_conversation_tracker = None - - # Keep original user input separate from session-prepared input - original_user_input = input - prepared_input = await self._prepare_input_with_session( - input, session, run_config.session_input_callback - ) - - tool_use_tracker = AgentToolUseTracker() - - with TraceCtxManager( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ): - current_turn = 0 - original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) - generated_items: list[RunItem] = [] - model_responses: list[ModelResponse] = [] - - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context, # type: ignore - ) - - input_guardrail_results: list[InputGuardrailResult] = [] - tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] - tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] - - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - should_run_agent_start_hooks = True - - # save only the new user input to the session, not the combined history - await self._save_result_to_session(session, original_user_input, []) - - try: - while True: - all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [ - h.agent_name - for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := AgentRunner._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) - current_span.span_data.tools = [t.name for t in all_tools] - - current_turn += 1 - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") - - logger.debug( - f"Running agent {current_agent.name} (turn {current_turn})", - ) - - if current_turn == 1: - input_guardrail_results, turn_result = await asyncio.gather( - self._run_input_guardrails( - starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), - _copy_str_or_list(prepared_input), - context_wrapper, - ), - self._run_single_turn( - agent=current_agent, - all_tools=all_tools, - original_input=original_input, - generated_items=generated_items, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - should_run_agent_start_hooks=should_run_agent_start_hooks, - tool_use_tracker=tool_use_tracker, - server_conversation_tracker=server_conversation_tracker, - ), - ) - else: - turn_result = await self._run_single_turn( - agent=current_agent, - all_tools=all_tools, - original_input=original_input, - generated_items=generated_items, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - should_run_agent_start_hooks=should_run_agent_start_hooks, - tool_use_tracker=tool_use_tracker, - server_conversation_tracker=server_conversation_tracker, - ) - should_run_agent_start_hooks = False - - model_responses.append(turn_result.model_response) - original_input = turn_result.original_input - generated_items = turn_result.generated_items - - if server_conversation_tracker is not None: - server_conversation_tracker.track_server_items(turn_result.model_response) - - # Collect tool guardrail results from this turn - tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) - tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) - - if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await self._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - result = RunResult( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - final_output=turn_result.next_step.output, - _last_agent=current_agent, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=output_guardrail_results, - tool_input_guardrail_results=tool_input_guardrail_results, - tool_output_guardrail_results=tool_output_guardrail_results, - context_wrapper=context_wrapper, - ) - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items - ) - - return result - elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - elif isinstance(turn_result.next_step, NextStepRunAgain): - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items - ) - else: - raise AgentsException( - f"Unknown next step type: {type(turn_result.next_step)}" - ) - except AgentsException as exc: - exc.run_data = RunErrorDetails( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - last_agent=current_agent, - context_wrapper=context_wrapper, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=[], - ) - raise - finally: - if current_span: - current_span.finish(reset_current=True) - - def run_sync( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResult: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - - # Python 3.14 stopped implicitly wiring up a default event loop - # when synchronous code touches asyncio APIs for the first time. - # Several of our synchronous entry points (for example the Redis/SQLAlchemy session helpers) - # construct asyncio primitives like asyncio.Lock during __init__, - # which binds them to whatever loop happens to be the thread's default at that moment. - # To keep those locks usable we must ensure that run_sync reuses that same default loop - # instead of hopping over to a brand-new asyncio.run() loop. - try: - already_running_loop = asyncio.get_running_loop() - except RuntimeError: - already_running_loop = None - - if already_running_loop is not None: - # This method is only expected to run when no loop is already active. - # (Each thread has its own default loop; concurrent sync runs should happen on - # different threads. In a single thread use the async API to interleave work.) - raise RuntimeError( - "AgentRunner.run_sync() cannot be called when an event loop is already running." - ) - - policy = asyncio.get_event_loop_policy() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - try: - default_loop = policy.get_event_loop() - except RuntimeError: - default_loop = policy.new_event_loop() - policy.set_event_loop(default_loop) - - # We intentionally leave the default loop open even if we had to create one above. Session - # instances and other helpers stash loop-bound primitives between calls and expect to find - # the same default loop every time run_sync is invoked on this thread. - # Schedule the async run on the default loop so that we can manage cancellation explicitly. - task = default_loop.create_task( - self.run( - starting_agent, - input, - session=session, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - ) - ) - - try: - # Drive the coroutine to completion, harvesting the final RunResult. - return default_loop.run_until_complete(task) - except BaseException: - # If the sync caller aborts (KeyboardInterrupt, etc.), make sure the scheduled task - # does not linger on the shared loop by cancelling it and waiting for completion. - if not task.done(): - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - default_loop.run_until_complete(task) - raise - finally: - if not default_loop.is_closed(): - # The loop stays open for subsequent runs, but we still need to flush any pending - # async generators so their cleanup code executes promptly. - with contextlib.suppress(RuntimeError): - default_loop.run_until_complete(default_loop.shutdown_asyncgens()) - - def run_streamed( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResultStreaming: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - - if run_config is None: - run_config = RunConfig() - - # If there's already a trace, we don't create a new one. In addition, we can't end the - # trace here, because the actual work is done in `stream_events` and this method ends - # before that. - new_trace = ( - None - if get_current_trace() - else trace( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ) - ) - - output_schema = AgentRunner._get_output_schema(starting_agent) - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context # type: ignore - ) - - streamed_result = RunResultStreaming( - input=_copy_str_or_list(input), - new_items=[], - current_agent=starting_agent, - raw_responses=[], - final_output=None, - is_complete=False, - current_turn=0, - max_turns=max_turns, - input_guardrail_results=[], - output_guardrail_results=[], - tool_input_guardrail_results=[], - tool_output_guardrail_results=[], - _current_agent_output_schema=output_schema, - trace=new_trace, - context_wrapper=context_wrapper, - ) - - # Kick off the actual agent loop in the background and return the streamed result object. - streamed_result._run_impl_task = asyncio.create_task( - self._start_streaming( - starting_input=input, - streamed_result=streamed_result, - starting_agent=starting_agent, - max_turns=max_turns, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - ) - return streamed_result - - @staticmethod - def _validate_run_hooks( - hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, - ) -> RunHooks[Any]: - if hooks is None: - return RunHooks[Any]() - input_hook_type = type(hooks).__name__ - if isinstance(hooks, AgentHooksBase): - raise TypeError( - "Run hooks must be instances of RunHooks. " - f"Received agent-scoped hooks ({input_hook_type}). " - "Attach AgentHooks to an Agent via Agent(..., hooks=...)." - ) - if not isinstance(hooks, RunHooksBase): - raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") - return hooks - - @classmethod - async def _maybe_filter_model_input( - cls, - *, - agent: Agent[TContext], - run_config: RunConfig, - context_wrapper: RunContextWrapper[TContext], - input_items: list[TResponseInputItem], - system_instructions: str | None, - ) -> ModelInputData: - """Apply optional call_model_input_filter to modify model input. - - Returns a `ModelInputData` that will be sent to the model. - """ - effective_instructions = system_instructions - effective_input: list[TResponseInputItem] = input_items - - if run_config.call_model_input_filter is None: - return ModelInputData(input=effective_input, instructions=effective_instructions) - - try: - model_input = ModelInputData( - input=effective_input.copy(), - instructions=effective_instructions, - ) - filter_payload: CallModelData[TContext] = CallModelData( - model_data=model_input, - agent=agent, - context=context_wrapper.context, - ) - maybe_updated = run_config.call_model_input_filter(filter_payload) - updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated - if not isinstance(updated, ModelInputData): - raise UserError("call_model_input_filter must return a ModelInputData instance") - return updated - except Exception as e: - _error_tracing.attach_error_to_current_span( - SpanError(message="Error in call_model_input_filter", data={"error": str(e)}) - ) - raise - - @classmethod - async def _run_input_guardrails_with_queue( - cls, - agent: Agent[Any], - guardrails: list[InputGuardrail[TContext]], - input: str | list[TResponseInputItem], - context: RunContextWrapper[TContext], - streamed_result: RunResultStreaming, - parent_span: Span[Any], - ): - queue = streamed_result._input_guardrail_queue - - # We'll run the guardrails and push them onto the queue as they complete - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) - ) - for guardrail in guardrails - ] - guardrail_results = [] - try: - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - _error_tracing.attach_error_to_span( - parent_span, - SpanError( - message="Guardrail tripwire triggered", - data={ - "guardrail": result.guardrail.get_name(), - "type": "input_guardrail", - }, - ), - ) - queue.put_nowait(result) - guardrail_results.append(result) - except Exception: - for t in guardrail_tasks: - t.cancel() - raise - - streamed_result.input_guardrail_results = guardrail_results - - @classmethod - async def _start_streaming( - cls, - starting_input: str | list[TResponseInputItem], - streamed_result: RunResultStreaming, - starting_agent: Agent[TContext], - max_turns: int, - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - previous_response_id: str | None, - conversation_id: str | None, - session: Session | None, - ): - if streamed_result.trace: - streamed_result.trace.start(mark_as_current=True) - - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - current_turn = 0 - should_run_agent_start_hooks = True - tool_use_tracker = AgentToolUseTracker() - - if conversation_id is not None or previous_response_id is not None: - server_conversation_tracker = _ServerConversationTracker( - conversation_id=conversation_id, previous_response_id=previous_response_id - ) - else: - server_conversation_tracker = None - - streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) - - try: - # Prepare input with session if enabled - prepared_input = await AgentRunner._prepare_input_with_session( - starting_input, session, run_config.session_input_callback - ) - - # Update the streamed result with the prepared input - streamed_result.input = prepared_input - - await AgentRunner._save_result_to_session(session, starting_input, []) - - while True: - # Check for soft cancel before starting new turn - if streamed_result._cancel_mode == "after_turn": - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if streamed_result.is_complete: - break - - all_tools = await cls._get_all_tools(current_agent, context_wrapper) - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [ - h.agent_name - for h in await cls._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) - tool_names = [t.name for t in all_tools] - current_span.span_data.tools = tool_names - current_turn += 1 - streamed_result.current_turn = current_turn - - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if current_turn == 1: - # Run the input guardrails in the background and put the results on the queue - streamed_result._input_guardrails_task = asyncio.create_task( - cls._run_input_guardrails_with_queue( - starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), - ItemHelpers.input_to_new_input_list(prepared_input), - context_wrapper, - streamed_result, - current_span, - ) - ) - try: - turn_result = await cls._run_single_turn_streamed( - streamed_result, - current_agent, - hooks, - context_wrapper, - run_config, - should_run_agent_start_hooks, - tool_use_tracker, - all_tools, - server_conversation_tracker, - ) - should_run_agent_start_hooks = False - - streamed_result.raw_responses = streamed_result.raw_responses + [ - turn_result.model_response - ] - streamed_result.input = turn_result.original_input - streamed_result.new_items = turn_result.generated_items - - if server_conversation_tracker is not None: - server_conversation_tracker.track_server_items(turn_result.model_response) - - if isinstance(turn_result.next_step, NextStepHandoff): - # Save the conversation to session if enabled (before handoff) - # Note: Non-streaming path doesn't save handoff turns immediately, - # but streaming needs to for graceful cancellation support - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - streamed_result._event_queue.put_nowait( - AgentUpdatedStreamEvent(new_agent=current_agent) - ) - - # Check for soft cancel after handoff - if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - elif isinstance(turn_result.next_step, NextStepFinalOutput): - streamed_result._output_guardrails_task = asyncio.create_task( - cls._run_output_guardrails( - current_agent.output_guardrails - + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - ) - - try: - output_guardrail_results = await streamed_result._output_guardrails_task - except Exception: - # Exceptions will be checked in the stream_events loop - output_guardrail_results = [] - - streamed_result.output_guardrail_results = output_guardrail_results - streamed_result.final_output = turn_result.next_step.output - streamed_result.is_complete = True - - # Save the conversation to session if enabled - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - elif isinstance(turn_result.next_step, NextStepRunAgain): - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - # Check for soft cancel after turn completion - if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - except AgentsException as exc: - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - exc.run_data = RunErrorDetails( - input=streamed_result.input, - new_items=streamed_result.new_items, - raw_responses=streamed_result.raw_responses, - last_agent=current_agent, - context_wrapper=context_wrapper, - input_guardrail_results=streamed_result.input_guardrail_results, - output_guardrail_results=streamed_result.output_guardrail_results, - ) - raise - except Exception as e: - if current_span: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), - ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - raise - - streamed_result.is_complete = True - finally: - if streamed_result._input_guardrails_task: - try: - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - except Exception as e: - logger.debug( - f"Error in streamed_result finalize for agent {current_agent.name} - {e}" - ) - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) - - @classmethod - async def _run_single_turn_streamed( - cls, - streamed_result: RunResultStreaming, - agent: Agent[TContext], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - should_run_agent_start_hooks: bool, - tool_use_tracker: AgentToolUseTracker, - all_tools: list[Tool], - server_conversation_tracker: _ServerConversationTracker | None = None, - ) -> SingleStepResult: - emitted_tool_call_ids: set[str] = set() - emitted_reasoning_item_ids: set[str] = set() - - if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - output_schema = cls._get_output_schema(agent) - - streamed_result.current_agent = agent - streamed_result._current_agent_output_schema = output_schema - - system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), - ) - - handoffs = await cls._get_handoffs(agent, context_wrapper) - model = cls._get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) - - final_response: ModelResponse | None = None - - if server_conversation_tracker is not None: - input = server_conversation_tracker.prepare_input( - streamed_result.input, streamed_result.new_items - ) - else: - input = ItemHelpers.input_to_new_input_list(streamed_result.input) - input.extend([item.to_input_item() for item in streamed_result.new_items]) - - # THIS IS THE RESOLVED CONFLICT BLOCK - filtered = await cls._maybe_filter_model_input( - agent=agent, - run_config=run_config, - context_wrapper=context_wrapper, - input_items=input, - system_instructions=system_prompt, - ) - - # Call hook just before the model is invoked, with the correct system_prompt. - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, agent, filtered.instructions, filtered.input - ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) - - # 1. Stream the output events - async for event in model.stream_response( - filtered.instructions, - filtered.input, - model_settings, - all_tools, - output_schema, - handoffs, - get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - enable_structured_output_with_tools=agent.enable_structured_output_with_tools, - ): - # Emit the raw event ASAP - streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) - - if isinstance(event, ResponseCompletedEvent): - usage = ( - Usage( - requests=1, - input_tokens=event.response.usage.input_tokens, - output_tokens=event.response.usage.output_tokens, - total_tokens=event.response.usage.total_tokens, - input_tokens_details=event.response.usage.input_tokens_details, - output_tokens_details=event.response.usage.output_tokens_details, - ) - if event.response.usage - else Usage() - ) - final_response = ModelResponse( - output=event.response.output, - usage=usage, - response_id=event.response.id, - ) - context_wrapper.usage.add(usage) - - if isinstance(event, ResponseOutputItemDoneEvent): - output_item = event.item - - if isinstance(output_item, _TOOL_CALL_TYPES): - call_id: str | None = getattr( - output_item, "call_id", getattr(output_item, "id", None) - ) - - if call_id and call_id not in emitted_tool_call_ids: - emitted_tool_call_ids.add(call_id) - - tool_item = ToolCallItem( - raw_item=cast(ToolCallItemTypes, output_item), - agent=agent, - ) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=tool_item, name="tool_called") - ) - - elif isinstance(output_item, ResponseReasoningItem): - reasoning_id: str | None = getattr(output_item, "id", None) - - if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: - emitted_reasoning_item_ids.add(reasoning_id) - - reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") - ) - - # Call hook just after the model response is finalized. - if final_response is not None: - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, final_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, final_response), - ) - - # 2. At this point, the streaming is complete for this turn of the agent loop. - if not final_response: - raise ModelBehaviorError("Model did not produce a final response!") - - # 3. Now, we can process the turn as we do in the non-streaming case - single_step_result = await cls._get_single_step_result_from_response( - agent=agent, - original_input=streamed_result.input, - pre_step_items=streamed_result.new_items, - new_response=final_response, - output_schema=output_schema, - all_tools=all_tools, - handoffs=handoffs, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - tool_use_tracker=tool_use_tracker, - event_queue=streamed_result._event_queue, - ) - - import dataclasses as _dc - - # Filter out items that have already been sent to avoid duplicates - items_to_filter = single_step_result.new_step_items - - if emitted_tool_call_ids: - # Filter out tool call items that were already emitted during streaming - items_to_filter = [ - item - for item in items_to_filter - if not ( - isinstance(item, ToolCallItem) - and ( - call_id := getattr( - item.raw_item, "call_id", getattr(item.raw_item, "id", None) - ) - ) - and call_id in emitted_tool_call_ids - ) - ] - - if emitted_reasoning_item_ids: - # Filter out reasoning items that were already emitted during streaming - items_to_filter = [ - item - for item in items_to_filter - if not ( - isinstance(item, ReasoningItem) - and (reasoning_id := getattr(item.raw_item, "id", None)) - and reasoning_id in emitted_reasoning_item_ids - ) - ] - - # Filter out HandoffCallItem to avoid duplicates (already sent earlier) - items_to_filter = [ - item for item in items_to_filter if not isinstance(item, HandoffCallItem) - ] - - # Create filtered result and send to queue - filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter) - RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) - return single_step_result - - @classmethod - async def _run_single_turn( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - original_input: str | list[TResponseInputItem], - generated_items: list[RunItem], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - should_run_agent_start_hooks: bool, - tool_use_tracker: AgentToolUseTracker, - server_conversation_tracker: _ServerConversationTracker | None = None, - ) -> SingleStepResult: - # Ensure we run the hooks before anything else - if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), - ) - - output_schema = cls._get_output_schema(agent) - handoffs = await cls._get_handoffs(agent, context_wrapper) - if server_conversation_tracker is not None: - input = server_conversation_tracker.prepare_input(original_input, generated_items) - else: - input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) - - new_response = await cls._get_new_response( - agent, - system_prompt, - input, - output_schema, - all_tools, - handoffs, - hooks, - context_wrapper, - run_config, - tool_use_tracker, - server_conversation_tracker, - prompt_config, - ) - - return await cls._get_single_step_result_from_response( - agent=agent, - original_input=original_input, - pre_step_items=generated_items, - new_response=new_response, - output_schema=output_schema, - all_tools=all_tools, - handoffs=handoffs, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - tool_use_tracker=tool_use_tracker, - ) - - @classmethod - async def _get_single_step_result_from_response( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - original_input: str | list[TResponseInputItem], - pre_step_items: list[RunItem], - new_response: ModelResponse, - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, - ) -> SingleStepResult: - processed_response = RunImpl.process_model_response( - agent=agent, - all_tools=all_tools, - response=new_response, - output_schema=output_schema, - handoffs=handoffs, - ) - - tool_use_tracker.add_tool_use(agent, processed_response.tools_used) - - # Send handoff items immediately for streaming, but avoid duplicates - if event_queue is not None and processed_response.new_items: - handoff_items = [ - item for item in processed_response.new_items if isinstance(item, HandoffCallItem) - ] - if handoff_items: - RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) - - return await RunImpl.execute_tools_and_side_effects( - agent=agent, - original_input=original_input, - pre_step_items=pre_step_items, - new_response=new_response, - processed_response=processed_response, - output_schema=output_schema, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - ) - - @classmethod - async def _get_single_step_result_from_streamed_response( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - streamed_result: RunResultStreaming, - new_response: ModelResponse, - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - ) -> SingleStepResult: - original_input = streamed_result.input - pre_step_items = streamed_result.new_items - event_queue = streamed_result._event_queue - - processed_response = RunImpl.process_model_response( - agent=agent, - all_tools=all_tools, - response=new_response, - output_schema=output_schema, - handoffs=handoffs, - ) - new_items_processed_response = processed_response.new_items - tool_use_tracker.add_tool_use(agent, processed_response.tools_used) - RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue) - - single_step_result = await RunImpl.execute_tools_and_side_effects( - agent=agent, - original_input=original_input, - pre_step_items=pre_step_items, - new_response=new_response, - processed_response=processed_response, - output_schema=output_schema, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - ) - new_step_items = [ - item - for item in single_step_result.new_step_items - if item not in new_items_processed_response - ] - RunImpl.stream_step_items_to_queue(new_step_items, event_queue) - - return single_step_result - - @classmethod - async def _run_input_guardrails( - cls, - agent: Agent[Any], - guardrails: list[InputGuardrail[TContext]], - input: str | list[TResponseInputItem], - context: RunContextWrapper[TContext], - ) -> list[InputGuardrailResult]: - if not guardrails: - return [] - - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) - ) - for guardrail in guardrails - ] - - guardrail_results = [] - - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - # Cancel all guardrail tasks if a tripwire is triggered. - for t in guardrail_tasks: - t.cancel() - _error_tracing.attach_error_to_current_span( - SpanError( - message="Guardrail tripwire triggered", - data={"guardrail": result.guardrail.get_name()}, - ) - ) - raise InputGuardrailTripwireTriggered(result) - else: - guardrail_results.append(result) - - return guardrail_results - - @classmethod - async def _run_output_guardrails( - cls, - guardrails: list[OutputGuardrail[TContext]], - agent: Agent[TContext], - agent_output: Any, - context: RunContextWrapper[TContext], - ) -> list[OutputGuardrailResult]: - if not guardrails: - return [] - - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) - ) - for guardrail in guardrails - ] - - guardrail_results = [] - - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - # Cancel all guardrail tasks if a tripwire is triggered. - for t in guardrail_tasks: - t.cancel() - _error_tracing.attach_error_to_current_span( - SpanError( - message="Guardrail tripwire triggered", - data={"guardrail": result.guardrail.get_name()}, - ) - ) - raise OutputGuardrailTripwireTriggered(result) - else: - guardrail_results.append(result) - - return guardrail_results - - @classmethod - async def _get_new_response( - cls, - agent: Agent[TContext], - system_prompt: str | None, - input: list[TResponseInputItem], - output_schema: AgentOutputSchemaBase | None, - all_tools: list[Tool], - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - server_conversation_tracker: _ServerConversationTracker | None, - prompt_config: ResponsePromptParam | None, - ) -> ModelResponse: - # Allow user to modify model input right before the call, if configured - filtered = await cls._maybe_filter_model_input( - agent=agent, - run_config=run_config, - context_wrapper=context_wrapper, - input_items=input, - system_instructions=system_prompt, - ) - - model = cls._get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) - - # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, - agent, - filtered.instructions, # Use filtered instructions - filtered.input, # Use filtered input - ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) - - new_response = await model.get_response( - system_instructions=filtered.instructions, - input=filtered.input, - model_settings=model_settings, - tools=all_tools, - output_schema=output_schema, - handoffs=handoffs, - tracing=get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - enable_structured_output_with_tools=agent.enable_structured_output_with_tools, - ) - - context_wrapper.usage.add(new_response.usage) - - # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, new_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, new_response), - ) - - return new_response - - @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: - if agent.output_type is None or agent.output_type is str: - return None - elif isinstance(agent.output_type, AgentOutputSchemaBase): - return agent.output_type - - return AgentOutputSchema(agent.output_type) - - @classmethod - async def _get_handoffs( - cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] - ) -> list[Handoff]: - handoffs = [] - for handoff_item in agent.handoffs: - if isinstance(handoff_item, Handoff): - handoffs.append(handoff_item) - elif isinstance(handoff_item, Agent): - handoffs.append(handoff(handoff_item)) - - async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: - attr = handoff_obj.is_enabled - if isinstance(attr, bool): - return attr - res = attr(context_wrapper, agent) - if inspect.isawaitable(res): - return bool(await res) - return bool(res) - - results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) - enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] - return enabled - - @classmethod - async def _get_all_tools( - cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] - ) -> list[Tool]: - return await agent.get_all_tools(context_wrapper) - - @classmethod - def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: - if isinstance(run_config.model, Model): - return run_config.model - elif isinstance(run_config.model, str): - return run_config.model_provider.get_model(run_config.model) - elif isinstance(agent.model, Model): - return agent.model - - return run_config.model_provider.get_model(agent.model) - - @classmethod - async def _prepare_input_with_session( - cls, - input: str | list[TResponseInputItem], - session: Session | None, - session_input_callback: SessionInputCallback | None, - ) -> str | list[TResponseInputItem]: - """Prepare input by combining it with session history if enabled.""" - if session is None: - return input - - # If the user doesn't specify an input callback and pass a list as input - if isinstance(input, list) and not session_input_callback: - raise UserError( - "When using session memory, list inputs require a " - "`RunConfig.session_input_callback` to define how they should be merged " - "with the conversation history. If you don't want to use a callback, " - "provide your input as a string instead, or disable session memory " - "(session=None) and pass a list to manage the history manually." - ) - - # Get previous conversation history - history = await session.get_items() - - # Convert input to list format - new_input_list = ItemHelpers.input_to_new_input_list(input) - - if session_input_callback is None: - return history + new_input_list - elif callable(session_input_callback): - res = session_input_callback(history, new_input_list) - if inspect.isawaitable(res): - return await res - return res - else: - raise UserError( - f"Invalid `session_input_callback` value: {session_input_callback}. " - "Choose between `None` or a custom callable function." - ) - - @classmethod - async def _save_result_to_session( - cls, - session: Session | None, - original_input: str | list[TResponseInputItem], - new_items: list[RunItem], - ) -> None: - """ - Save the conversation turn to session. - It does not account for any filtering or modification performed by - `RunConfig.session_input_callback`. - """ - if session is None: - return - - # Convert original input to list format if needed - input_list = ItemHelpers.input_to_new_input_list(original_input) - - # Convert new items to input format - new_items_as_input = [item.to_input_item() for item in new_items] - - # Save all items from this turn - items_to_save = input_list + new_items_as_input - await session.add_items(items_to_save) - - @staticmethod - async def _input_guardrail_tripwire_triggered_for_stream( - streamed_result: RunResultStreaming, - ) -> bool: - """Return True if any input guardrail triggered during a streamed run.""" - - task = streamed_result._input_guardrails_task - if task is None: - return False - - if not task.done(): - await task - - return any( - guardrail_result.output.tripwire_triggered - for guardrail_result in streamed_result.input_guardrail_results - ) - - -DEFAULT_AGENT_RUNNER = AgentRunner() -_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) - - -def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: - if isinstance(input, str): - return input - return input.copy() +from __future__ import annotations + +import asyncio +import contextlib +import inspect +import os +import warnings +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, cast, get_args + +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseOutputItemDoneEvent, +) +from openai.types.responses.response_prompt_param import ( + ResponsePromptParam, +) +from openai.types.responses.response_reasoning_item import ResponseReasoningItem +from typing_extensions import NotRequired, TypedDict, Unpack + +from ._run_impl import ( + AgentToolUseTracker, + NextStepFinalOutput, + NextStepHandoff, + NextStepRunAgain, + QueueCompleteSentinel, + RunImpl, + SingleStepResult, + TraceCtxManager, + get_model_tracing_impl, +) +from .agent import Agent +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase +from .exceptions import ( + AgentsException, + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, + ModelBehaviorError, + OutputGuardrailTripwireTriggered, + RunErrorDetails, + UserError, +) +from .guardrail import ( + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, +) +from .handoffs import Handoff, HandoffInputFilter, handoff +from .items import ( + HandoffCallItem, + ItemHelpers, + ModelResponse, + ReasoningItem, + RunItem, + ToolCallItem, + ToolCallItemTypes, + TResponseInputItem, +) +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase +from .logger import logger +from .memory import Session, SessionInputCallback +from .model_settings import ModelSettings +from .models.interface import Model, ModelProvider +from .models.multi_provider import MultiProvider +from .result import RunResult, RunResultStreaming +from .run_context import RunContextWrapper, TContext +from .stream_events import ( + AgentUpdatedStreamEvent, + RawResponsesStreamEvent, + RunItemStreamEvent, + StreamEvent, +) +from .tool import Tool +from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult +from .tracing import Span, SpanError, agent_span, get_current_trace, trace +from .tracing.span_data import AgentSpanData +from .usage import Usage +from .util import _coro, _error_tracing +from .util._types import MaybeAwaitable + +DEFAULT_MAX_TURNS = 10 + +DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore +# the value is set at the end of the module + + +def set_default_agent_runner(runner: AgentRunner | None) -> None: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly. + """ + global DEFAULT_AGENT_RUNNER + DEFAULT_AGENT_RUNNER = runner or AgentRunner() + + +def get_default_agent_runner() -> AgentRunner: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly. + """ + global DEFAULT_AGENT_RUNNER + return DEFAULT_AGENT_RUNNER + + +def _default_trace_include_sensitive_data() -> bool: + """Returns the default value for trace_include_sensitive_data based on environment variable.""" + val = os.getenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") + return val.strip().lower() in ("1", "true", "yes", "on") + + +@dataclass +class ModelInputData: + """Container for the data that will be sent to the model.""" + + input: list[TResponseInputItem] + instructions: str | None + + +@dataclass +class CallModelData(Generic[TContext]): + """Data passed to `RunConfig.call_model_input_filter` prior to model call.""" + + model_data: ModelInputData + agent: Agent[TContext] + context: TContext | None + + +@dataclass +class _ServerConversationTracker: + """Tracks server-side conversation state for either conversation_id or + previous_response_id modes.""" + + conversation_id: str | None = None + previous_response_id: str | None = None + sent_items: set[int] = field(default_factory=set) + server_items: set[int] = field(default_factory=set) + + def track_server_items(self, model_response: ModelResponse) -> None: + for output_item in model_response.output: + self.server_items.add(id(output_item)) + + # Update previous_response_id only when using previous_response_id + if ( + self.conversation_id is None + and self.previous_response_id is not None + and model_response.response_id is not None + ): + self.previous_response_id = model_response.response_id + + def prepare_input( + self, + original_input: str | list[TResponseInputItem], + generated_items: list[RunItem], + ) -> list[TResponseInputItem]: + input_items: list[TResponseInputItem] = [] + + # On first call (when there are no generated items yet), include the original input + if not generated_items: + input_items.extend(ItemHelpers.input_to_new_input_list(original_input)) + + # Process generated_items, skip items already sent or from server + for item in generated_items: + raw_item_id = id(item.raw_item) + + if raw_item_id in self.sent_items or raw_item_id in self.server_items: + continue + input_items.append(item.to_input_item()) + self.sent_items.add(raw_item_id) + + return input_items + + +# Type alias for the optional input filter callback +CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] + + +@dataclass +class RunConfig: + """Configures settings for the entire agent run.""" + + model: str | Model | None = None + """The model to use for the entire agent run. If set, will override the model set on every + agent. The model_provider passed in below must be able to resolve this model name. + """ + + model_provider: ModelProvider = field(default_factory=MultiProvider) + """The model provider to use when looking up string model names. Defaults to OpenAI.""" + + model_settings: ModelSettings | None = None + """Configure global model settings. Any non-null values will override the agent-specific model + settings. + """ + + handoff_input_filter: HandoffInputFilter | None = None + """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that + will take precedence. The input filter allows you to edit the inputs that are sent to the new + agent. See the documentation in `Handoff.input_filter` for more details. + """ + + input_guardrails: list[InputGuardrail[Any]] | None = None + """A list of input guardrails to run on the initial run input.""" + + output_guardrails: list[OutputGuardrail[Any]] | None = None + """A list of output guardrails to run on the final output of the run.""" + + tracing_disabled: bool = False + """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. + """ + + trace_include_sensitive_data: bool = field( + default_factory=_default_trace_include_sensitive_data + ) + """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or + LLM generations) in traces. If False, we'll still create spans for these events, but the + sensitive data will not be included. + """ + + workflow_name: str = "Agent workflow" + """The name of the run, used for tracing. Should be a logical name for the run, like + "Code generation workflow" or "Customer support agent". + """ + + trace_id: str | None = None + """A custom trace ID to use for tracing. If not provided, we will generate a new trace ID.""" + + group_id: str | None = None + """ + A grouping identifier to use for tracing, to link multiple traces from the same conversation + or process. For example, you might use a chat thread ID. + """ + + trace_metadata: dict[str, Any] | None = None + """ + An optional dictionary of additional metadata to include with the trace. + """ + + session_input_callback: SessionInputCallback | None = None + """Defines how to handle session history when new input is provided. + - `None` (default): The new input is appended to the session history. + - `SessionInputCallback`: A custom function that receives the history and new input, and + returns the desired combined list of items. + """ + + call_model_input_filter: CallModelInputFilter | None = None + """ + Optional callback that is invoked immediately before calling the model. It receives the current + agent, context and the model input (instructions and input items), and must return a possibly + modified `ModelInputData` to use for the model call. + + This allows you to edit the input sent to the model e.g. to stay within a token limit. + For example, you can use this to add a system prompt to the input. + """ + + +class RunOptions(TypedDict, Generic[TContext]): + """Arguments for ``AgentRunner`` methods.""" + + context: NotRequired[TContext | None] + """The context for the run.""" + + max_turns: NotRequired[int] + """The maximum number of turns to run for.""" + + hooks: NotRequired[RunHooks[TContext] | None] + """Lifecycle hooks for the run.""" + + run_config: NotRequired[RunConfig | None] + """Run configuration.""" + + previous_response_id: NotRequired[str | None] + """The ID of the previous response, if any.""" + + conversation_id: NotRequired[str | None] + """The ID of the stored conversation, if any.""" + + session: NotRequired[Session | None] + """The session for the run.""" + + +class Runner: + @classmethod + async def run( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResult: + """ + Run a workflow starting at the given agent. + + The agent will run in a loop until a final output is generated. The loop runs like so: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response. If using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The conversation ID + (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). + If provided, the conversation will be used to read and write items. + Every agent will have access to the conversation history so far, + and its output items will be written to the conversation. + We recommend only using this if you are exclusively using OpenAI models; + other model providers don't write to the Conversation object, + so you'll end up having partial conversations stored. + session: A session for automatic conversation history management. + + Returns: + A run result containing all the inputs, guardrail results and the output of + the last agent. Agents may perform handoffs, so we don't know the specific + type of the output. + """ + + runner = DEFAULT_AGENT_RUNNER + return await runner.run( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + @classmethod + def run_sync( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResult: + """ + Run a workflow synchronously, starting at the given agent. + + Note: + This just wraps the `run` method, so it will not work if there's already an + event loop (e.g. inside an async function, or in a Jupyter notebook or async + context like FastAPI). For those cases, use the `run` method instead. + + The agent will run in a loop until a final output is generated. The loop runs: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The ID of the stored conversation, if any. + session: A session for automatic conversation history management. + + Returns: + A run result containing all the inputs, guardrail results and the output of + the last agent. Agents may perform handoffs, so we don't know the specific + type of the output. + """ + + runner = DEFAULT_AGENT_RUNNER + return runner.run_sync( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + @classmethod + def run_streamed( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResultStreaming: + """ + Run a workflow starting at the given agent in streaming mode. + + The returned result object contains a method you can use to stream semantic + events as they are generated. + + The agent will run in a loop until a final output is generated. The loop runs like so: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The ID of the stored conversation, if any. + session: A session for automatic conversation history management. + + Returns: + A result object that contains data about the run, as well as a method to + stream events. + """ + + runner = DEFAULT_AGENT_RUNNER + return runner.run_streamed( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + +class AgentRunner: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly or subclassed. + """ + + async def run( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + if run_config is None: + run_config = RunConfig() + + if conversation_id is not None or previous_response_id is not None: + server_conversation_tracker = _ServerConversationTracker( + conversation_id=conversation_id, previous_response_id=previous_response_id + ) + else: + server_conversation_tracker = None + + # Keep original user input separate from session-prepared input + original_user_input = input + prepared_input = await self._prepare_input_with_session( + input, session, run_config.session_input_callback + ) + + tool_use_tracker = AgentToolUseTracker() + + with TraceCtxManager( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ): + current_turn = 0 + original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) + generated_items: list[RunItem] = [] + model_responses: list[ModelResponse] = [] + + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context, # type: ignore + ) + + input_guardrail_results: list[InputGuardrailResult] = [] + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] + + current_span: Span[AgentSpanData] | None = None + current_agent = starting_agent + should_run_agent_start_hooks = True + + # save only the new user input to the session, not the combined history + await self._save_result_to_session(session, original_user_input, []) + + try: + while True: + all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) + + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name + for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) + ] + if output_schema := AgentRunner._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + current_span.start(mark_as_current=True) + current_span.span_data.tools = [t.name for t in all_tools] + + current_turn += 1 + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), + ) + raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") + + logger.debug( + f"Running agent {current_agent.name} (turn {current_turn})", + ) + + if current_turn == 1: + input_guardrail_results, turn_result = await asyncio.gather( + self._run_input_guardrails( + starting_agent, + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + _copy_str_or_list(prepared_input), + context_wrapper, + ), + self._run_single_turn( + agent=current_agent, + all_tools=all_tools, + original_input=original_input, + generated_items=generated_items, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, + server_conversation_tracker=server_conversation_tracker, + ), + ) + else: + turn_result = await self._run_single_turn( + agent=current_agent, + all_tools=all_tools, + original_input=original_input, + generated_items=generated_items, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, + server_conversation_tracker=server_conversation_tracker, + ) + should_run_agent_start_hooks = False + + model_responses.append(turn_result.model_response) + original_input = turn_result.original_input + generated_items = turn_result.generated_items + + if server_conversation_tracker is not None: + server_conversation_tracker.track_server_items(turn_result.model_response) + + # Collect tool guardrail results from this turn + tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) + tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) + + if isinstance(turn_result.next_step, NextStepFinalOutput): + output_guardrail_results = await self._run_output_guardrails( + current_agent.output_guardrails + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=turn_result.next_step.output, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, + context_wrapper=context_wrapper, + ) + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) + + return result + elif isinstance(turn_result.next_step, NextStepHandoff): + current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + elif isinstance(turn_result.next_step, NextStepRunAgain): + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) + else: + raise AgentsException( + f"Unknown next step type: {type(turn_result.next_step)}" + ) + except AgentsException as exc: + exc.run_data = RunErrorDetails( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + ) + raise + finally: + if current_span: + current_span.finish(reset_current=True) + + def run_sync( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + + # Python 3.14 stopped implicitly wiring up a default event loop + # when synchronous code touches asyncio APIs for the first time. + # Several of our synchronous entry points (for example the Redis/SQLAlchemy session helpers) + # construct asyncio primitives like asyncio.Lock during __init__, + # which binds them to whatever loop happens to be the thread's default at that moment. + # To keep those locks usable we must ensure that run_sync reuses that same default loop + # instead of hopping over to a brand-new asyncio.run() loop. + try: + already_running_loop = asyncio.get_running_loop() + except RuntimeError: + already_running_loop = None + + if already_running_loop is not None: + # This method is only expected to run when no loop is already active. + # (Each thread has its own default loop; concurrent sync runs should happen on + # different threads. In a single thread use the async API to interleave work.) + raise RuntimeError( + "AgentRunner.run_sync() cannot be called when an event loop is already running." + ) + + policy = asyncio.get_event_loop_policy() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + try: + default_loop = policy.get_event_loop() + except RuntimeError: + default_loop = policy.new_event_loop() + policy.set_event_loop(default_loop) + + # We intentionally leave the default loop open even if we had to create one above. Session + # instances and other helpers stash loop-bound primitives between calls and expect to find + # the same default loop every time run_sync is invoked on this thread. + # Schedule the async run on the default loop so that we can manage cancellation explicitly. + task = default_loop.create_task( + self.run( + starting_agent, + input, + session=session, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + ) + ) + + try: + # Drive the coroutine to completion, harvesting the final RunResult. + return default_loop.run_until_complete(task) + except BaseException: + # If the sync caller aborts (KeyboardInterrupt, etc.), make sure the scheduled task + # does not linger on the shared loop by cancelling it and waiting for completion. + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + default_loop.run_until_complete(task) + raise + finally: + if not default_loop.is_closed(): + # The loop stays open for subsequent runs, but we still need to flush any pending + # async generators so their cleanup code executes promptly. + with contextlib.suppress(RuntimeError): + default_loop.run_until_complete(default_loop.shutdown_asyncgens()) + + def run_streamed( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResultStreaming: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + + if run_config is None: + run_config = RunConfig() + + # If there's already a trace, we don't create a new one. In addition, we can't end the + # trace here, because the actual work is done in `stream_events` and this method ends + # before that. + new_trace = ( + None + if get_current_trace() + else trace( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ) + ) + + output_schema = AgentRunner._get_output_schema(starting_agent) + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context # type: ignore + ) + + streamed_result = RunResultStreaming( + input=_copy_str_or_list(input), + new_items=[], + current_agent=starting_agent, + raw_responses=[], + final_output=None, + is_complete=False, + current_turn=0, + max_turns=max_turns, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _current_agent_output_schema=output_schema, + trace=new_trace, + context_wrapper=context_wrapper, + ) + + # Kick off the actual agent loop in the background and return the streamed result object. + streamed_result._run_impl_task = asyncio.create_task( + self._start_streaming( + starting_input=input, + streamed_result=streamed_result, + starting_agent=starting_agent, + max_turns=max_turns, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + ) + return streamed_result + + @staticmethod + def _validate_run_hooks( + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, + ) -> RunHooks[Any]: + if hooks is None: + return RunHooks[Any]() + input_hook_type = type(hooks).__name__ + if isinstance(hooks, AgentHooksBase): + raise TypeError( + "Run hooks must be instances of RunHooks. " + f"Received agent-scoped hooks ({input_hook_type}). " + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." + ) + if not isinstance(hooks, RunHooksBase): + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") + return hooks + + @classmethod + async def _maybe_filter_model_input( + cls, + *, + agent: Agent[TContext], + run_config: RunConfig, + context_wrapper: RunContextWrapper[TContext], + input_items: list[TResponseInputItem], + system_instructions: str | None, + ) -> ModelInputData: + """Apply optional call_model_input_filter to modify model input. + + Returns a `ModelInputData` that will be sent to the model. + """ + effective_instructions = system_instructions + effective_input: list[TResponseInputItem] = input_items + + if run_config.call_model_input_filter is None: + return ModelInputData(input=effective_input, instructions=effective_instructions) + + try: + model_input = ModelInputData( + input=effective_input.copy(), + instructions=effective_instructions, + ) + filter_payload: CallModelData[TContext] = CallModelData( + model_data=model_input, + agent=agent, + context=context_wrapper.context, + ) + maybe_updated = run_config.call_model_input_filter(filter_payload) + updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated + if not isinstance(updated, ModelInputData): + raise UserError("call_model_input_filter must return a ModelInputData instance") + return updated + except Exception as e: + _error_tracing.attach_error_to_current_span( + SpanError(message="Error in call_model_input_filter", data={"error": str(e)}) + ) + raise + + @classmethod + async def _run_input_guardrails_with_queue( + cls, + agent: Agent[Any], + guardrails: list[InputGuardrail[TContext]], + input: str | list[TResponseInputItem], + context: RunContextWrapper[TContext], + streamed_result: RunResultStreaming, + parent_span: Span[Any], + ): + queue = streamed_result._input_guardrail_queue + + # We'll run the guardrails and push them onto the queue as they complete + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + ) + for guardrail in guardrails + ] + guardrail_results = [] + try: + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + _error_tracing.attach_error_to_span( + parent_span, + SpanError( + message="Guardrail tripwire triggered", + data={ + "guardrail": result.guardrail.get_name(), + "type": "input_guardrail", + }, + ), + ) + queue.put_nowait(result) + guardrail_results.append(result) + except Exception: + for t in guardrail_tasks: + t.cancel() + raise + + streamed_result.input_guardrail_results = guardrail_results + + @classmethod + async def _start_streaming( + cls, + starting_input: str | list[TResponseInputItem], + streamed_result: RunResultStreaming, + starting_agent: Agent[TContext], + max_turns: int, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + previous_response_id: str | None, + conversation_id: str | None, + session: Session | None, + ): + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) + + current_span: Span[AgentSpanData] | None = None + current_agent = starting_agent + current_turn = 0 + should_run_agent_start_hooks = True + tool_use_tracker = AgentToolUseTracker() + + if conversation_id is not None or previous_response_id is not None: + server_conversation_tracker = _ServerConversationTracker( + conversation_id=conversation_id, previous_response_id=previous_response_id + ) + else: + server_conversation_tracker = None + + streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) + + try: + # Prepare input with session if enabled + prepared_input = await AgentRunner._prepare_input_with_session( + starting_input, session, run_config.session_input_callback + ) + + # Update the streamed result with the prepared input + streamed_result.input = prepared_input + + await AgentRunner._save_result_to_session(session, starting_input, []) + + while True: + # Check for soft cancel before starting new turn + if streamed_result._cancel_mode == "after_turn": + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if streamed_result.is_complete: + break + + all_tools = await cls._get_all_tools(current_agent, context_wrapper) + + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name + for h in await cls._get_handoffs(current_agent, context_wrapper) + ] + if output_schema := cls._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + current_span.start(mark_as_current=True) + tool_names = [t.name for t in all_tools] + current_span.span_data.tools = tool_names + current_turn += 1 + streamed_result.current_turn = current_turn + + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), + ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if current_turn == 1: + # Run the input guardrails in the background and put the results on the queue + streamed_result._input_guardrails_task = asyncio.create_task( + cls._run_input_guardrails_with_queue( + starting_agent, + starting_agent.input_guardrails + (run_config.input_guardrails or []), + ItemHelpers.input_to_new_input_list(prepared_input), + context_wrapper, + streamed_result, + current_span, + ) + ) + try: + turn_result = await cls._run_single_turn_streamed( + streamed_result, + current_agent, + hooks, + context_wrapper, + run_config, + should_run_agent_start_hooks, + tool_use_tracker, + all_tools, + server_conversation_tracker, + ) + should_run_agent_start_hooks = False + + streamed_result.raw_responses = streamed_result.raw_responses + [ + turn_result.model_response + ] + streamed_result.input = turn_result.original_input + streamed_result.new_items = turn_result.generated_items + + if server_conversation_tracker is not None: + server_conversation_tracker.track_server_items(turn_result.model_response) + + if isinstance(turn_result.next_step, NextStepHandoff): + # Save the conversation to session if enabled (before handoff) + # Note: Non-streaming path doesn't save handoff turns immediately, + # but streaming needs to for graceful cancellation support + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + current_agent = turn_result.next_step.new_agent + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) + + # Check for soft cancel after handoff + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + elif isinstance(turn_result.next_step, NextStepFinalOutput): + streamed_result._output_guardrails_task = asyncio.create_task( + cls._run_output_guardrails( + current_agent.output_guardrails + + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + ) + + try: + output_guardrail_results = await streamed_result._output_guardrails_task + except Exception: + # Exceptions will be checked in the stream_events loop + output_guardrail_results = [] + + streamed_result.output_guardrail_results = output_guardrail_results + streamed_result.final_output = turn_result.next_step.output + streamed_result.is_complete = True + + # Save the conversation to session if enabled + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + elif isinstance(turn_result.next_step, NextStepRunAgain): + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + # Check for soft cancel after turn completion + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + except AgentsException as exc: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + exc.run_data = RunErrorDetails( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + ) + raise + except Exception as e: + if current_span: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Error in agent run", + data={"error": str(e)}, + ), + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + + streamed_result.is_complete = True + finally: + if streamed_result._input_guardrails_task: + try: + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + except Exception as e: + logger.debug( + f"Error in streamed_result finalize for agent {current_agent.name} - {e}" + ) + if current_span: + current_span.finish(reset_current=True) + if streamed_result.trace: + streamed_result.trace.finish(reset_current=True) + + @classmethod + async def _run_single_turn_streamed( + cls, + streamed_result: RunResultStreaming, + agent: Agent[TContext], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, + all_tools: list[Tool], + server_conversation_tracker: _ServerConversationTracker | None = None, + ) -> SingleStepResult: + emitted_tool_call_ids: set[str] = set() + emitted_reasoning_item_ids: set[str] = set() + + if should_run_agent_start_hooks: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + output_schema = cls._get_output_schema(agent) + + streamed_result.current_agent = agent + streamed_result._current_agent_output_schema = output_schema + + system_prompt, prompt_config = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_prompt(context_wrapper), + ) + + handoffs = await cls._get_handoffs(agent, context_wrapper) + model = cls._get_model(agent, run_config) + model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + + final_response: ModelResponse | None = None + + if server_conversation_tracker is not None: + input = server_conversation_tracker.prepare_input( + streamed_result.input, streamed_result.new_items + ) + else: + input = ItemHelpers.input_to_new_input_list(streamed_result.input) + input.extend([item.to_input_item() for item in streamed_result.new_items]) + + # THIS IS THE RESOLVED CONFLICT BLOCK + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + + # Call hook just before the model is invoked, with the correct system_prompt. + await asyncio.gather( + hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + ( + agent.hooks.on_llm_start( + context_wrapper, agent, filtered.instructions, filtered.input + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) + + # 1. Stream the output events + # Build kwargs for model call + model_kwargs: dict[str, Any] = { + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "prompt": prompt_config, + } + + # Only pass enable_structured_output_with_tools when enabled + # to maintain backward compatibility with third-party Model implementations + if agent.enable_structured_output_with_tools: + model_kwargs["enable_structured_output_with_tools"] = True + + async for event in model.stream_response( + filtered.instructions, + filtered.input, + model_settings, + all_tools, + output_schema, + handoffs, + get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + **model_kwargs, + ): + # Emit the raw event ASAP + streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) + + if isinstance(event, ResponseCompletedEvent): + usage = ( + Usage( + requests=1, + input_tokens=event.response.usage.input_tokens, + output_tokens=event.response.usage.output_tokens, + total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, + ) + if event.response.usage + else Usage() + ) + final_response = ModelResponse( + output=event.response.output, + usage=usage, + response_id=event.response.id, + ) + context_wrapper.usage.add(usage) + + if isinstance(event, ResponseOutputItemDoneEvent): + output_item = event.item + + if isinstance(output_item, _TOOL_CALL_TYPES): + call_id: str | None = getattr( + output_item, "call_id", getattr(output_item, "id", None) + ) + + if call_id and call_id not in emitted_tool_call_ids: + emitted_tool_call_ids.add(call_id) + + tool_item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, output_item), + agent=agent, + ) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=tool_item, name="tool_called") + ) + + elif isinstance(output_item, ResponseReasoningItem): + reasoning_id: str | None = getattr(output_item, "id", None) + + if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: + emitted_reasoning_item_ids.add(reasoning_id) + + reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") + ) + + # Call hook just after the model response is finalized. + if final_response is not None: + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, final_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, final_response), + ) + + # 2. At this point, the streaming is complete for this turn of the agent loop. + if not final_response: + raise ModelBehaviorError("Model did not produce a final response!") + + # 3. Now, we can process the turn as we do in the non-streaming case + single_step_result = await cls._get_single_step_result_from_response( + agent=agent, + original_input=streamed_result.input, + pre_step_items=streamed_result.new_items, + new_response=final_response, + output_schema=output_schema, + all_tools=all_tools, + handoffs=handoffs, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + tool_use_tracker=tool_use_tracker, + event_queue=streamed_result._event_queue, + ) + + import dataclasses as _dc + + # Filter out items that have already been sent to avoid duplicates + items_to_filter = single_step_result.new_step_items + + if emitted_tool_call_ids: + # Filter out tool call items that were already emitted during streaming + items_to_filter = [ + item + for item in items_to_filter + if not ( + isinstance(item, ToolCallItem) + and ( + call_id := getattr( + item.raw_item, "call_id", getattr(item.raw_item, "id", None) + ) + ) + and call_id in emitted_tool_call_ids + ) + ] + + if emitted_reasoning_item_ids: + # Filter out reasoning items that were already emitted during streaming + items_to_filter = [ + item + for item in items_to_filter + if not ( + isinstance(item, ReasoningItem) + and (reasoning_id := getattr(item.raw_item, "id", None)) + and reasoning_id in emitted_reasoning_item_ids + ) + ] + + # Filter out HandoffCallItem to avoid duplicates (already sent earlier) + items_to_filter = [ + item for item in items_to_filter if not isinstance(item, HandoffCallItem) + ] + + # Create filtered result and send to queue + filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter) + RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) + return single_step_result + + @classmethod + async def _run_single_turn( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + original_input: str | list[TResponseInputItem], + generated_items: list[RunItem], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, + server_conversation_tracker: _ServerConversationTracker | None = None, + ) -> SingleStepResult: + # Ensure we run the hooks before anything else + if should_run_agent_start_hooks: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + system_prompt, prompt_config = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_prompt(context_wrapper), + ) + + output_schema = cls._get_output_schema(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) + if server_conversation_tracker is not None: + input = server_conversation_tracker.prepare_input(original_input, generated_items) + else: + input = ItemHelpers.input_to_new_input_list(original_input) + input.extend([generated_item.to_input_item() for generated_item in generated_items]) + + new_response = await cls._get_new_response( + agent, + system_prompt, + input, + output_schema, + all_tools, + handoffs, + hooks, + context_wrapper, + run_config, + tool_use_tracker, + server_conversation_tracker, + prompt_config, + ) + + return await cls._get_single_step_result_from_response( + agent=agent, + original_input=original_input, + pre_step_items=generated_items, + new_response=new_response, + output_schema=output_schema, + all_tools=all_tools, + handoffs=handoffs, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + tool_use_tracker=tool_use_tracker, + ) + + @classmethod + async def _get_single_step_result_from_response( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + original_input: str | list[TResponseInputItem], + pre_step_items: list[RunItem], + new_response: ModelResponse, + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, + ) -> SingleStepResult: + processed_response = RunImpl.process_model_response( + agent=agent, + all_tools=all_tools, + response=new_response, + output_schema=output_schema, + handoffs=handoffs, + ) + + tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + + # Send handoff items immediately for streaming, but avoid duplicates + if event_queue is not None and processed_response.new_items: + handoff_items = [ + item for item in processed_response.new_items if isinstance(item, HandoffCallItem) + ] + if handoff_items: + RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) + + return await RunImpl.execute_tools_and_side_effects( + agent=agent, + original_input=original_input, + pre_step_items=pre_step_items, + new_response=new_response, + processed_response=processed_response, + output_schema=output_schema, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + ) + + @classmethod + async def _get_single_step_result_from_streamed_response( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + streamed_result: RunResultStreaming, + new_response: ModelResponse, + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + ) -> SingleStepResult: + original_input = streamed_result.input + pre_step_items = streamed_result.new_items + event_queue = streamed_result._event_queue + + processed_response = RunImpl.process_model_response( + agent=agent, + all_tools=all_tools, + response=new_response, + output_schema=output_schema, + handoffs=handoffs, + ) + new_items_processed_response = processed_response.new_items + tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue) + + single_step_result = await RunImpl.execute_tools_and_side_effects( + agent=agent, + original_input=original_input, + pre_step_items=pre_step_items, + new_response=new_response, + processed_response=processed_response, + output_schema=output_schema, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + ) + new_step_items = [ + item + for item in single_step_result.new_step_items + if item not in new_items_processed_response + ] + RunImpl.stream_step_items_to_queue(new_step_items, event_queue) + + return single_step_result + + @classmethod + async def _run_input_guardrails( + cls, + agent: Agent[Any], + guardrails: list[InputGuardrail[TContext]], + input: str | list[TResponseInputItem], + context: RunContextWrapper[TContext], + ) -> list[InputGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise InputGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + + @classmethod + async def _run_output_guardrails( + cls, + guardrails: list[OutputGuardrail[TContext]], + agent: Agent[TContext], + agent_output: Any, + context: RunContextWrapper[TContext], + ) -> list[OutputGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise OutputGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + + @classmethod + async def _get_new_response( + cls, + agent: Agent[TContext], + system_prompt: str | None, + input: list[TResponseInputItem], + output_schema: AgentOutputSchemaBase | None, + all_tools: list[Tool], + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + server_conversation_tracker: _ServerConversationTracker | None, + prompt_config: ResponsePromptParam | None, + ) -> ModelResponse: + # Allow user to modify model input right before the call, if configured + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + + model = cls._get_model(agent, run_config) + model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + + # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call + await asyncio.gather( + hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + ( + agent.hooks.on_llm_start( + context_wrapper, + agent, + filtered.instructions, # Use filtered instructions + filtered.input, # Use filtered input + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) + + # Build kwargs for model call + model_kwargs: dict[str, Any] = { + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + "prompt": prompt_config, + } + + # Only pass enable_structured_output_with_tools when enabled + # to maintain backward compatibility with third-party Model implementations + if agent.enable_structured_output_with_tools: + model_kwargs["enable_structured_output_with_tools"] = True + + new_response = await model.get_response( + system_instructions=filtered.instructions, + input=filtered.input, + model_settings=model_settings, + tools=all_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + **model_kwargs, + ) + + context_wrapper.usage.add(new_response.usage) + + # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, new_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, new_response), + ) + + return new_response + + @classmethod + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: + if agent.output_type is None or agent.output_type is str: + return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type + + return AgentOutputSchema(agent.output_type) + + @classmethod + async def _get_handoffs( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff]: + handoffs = [] + for handoff_item in agent.handoffs: + if isinstance(handoff_item, Handoff): + handoffs.append(handoff_item) + elif isinstance(handoff_item, Agent): + handoffs.append(handoff(handoff_item)) + + async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] + return enabled + + @classmethod + async def _get_all_tools( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Tool]: + return await agent.get_all_tools(context_wrapper) + + @classmethod + def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: + if isinstance(run_config.model, Model): + return run_config.model + elif isinstance(run_config.model, str): + return run_config.model_provider.get_model(run_config.model) + elif isinstance(agent.model, Model): + return agent.model + + return run_config.model_provider.get_model(agent.model) + + @classmethod + async def _prepare_input_with_session( + cls, + input: str | list[TResponseInputItem], + session: Session | None, + session_input_callback: SessionInputCallback | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input + + # If the user doesn't specify an input callback and pass a list as input + if isinstance(input, list) and not session_input_callback: + raise UserError( + "When using session memory, list inputs require a " + "`RunConfig.session_input_callback` to define how they should be merged " + "with the conversation history. If you don't want to use a callback, " + "provide your input as a string instead, or disable session memory " + "(session=None) and pass a list to manage the history manually." + ) + + # Get previous conversation history + history = await session.get_items() + + # Convert input to list format + new_input_list = ItemHelpers.input_to_new_input_list(input) + + if session_input_callback is None: + return history + new_input_list + elif callable(session_input_callback): + res = session_input_callback(history, new_input_list) + if inspect.isawaitable(res): + return await res + return res + else: + raise UserError( + f"Invalid `session_input_callback` value: {session_input_callback}. " + "Choose between `None` or a custom callable function." + ) + + @classmethod + async def _save_result_to_session( + cls, + session: Session | None, + original_input: str | list[TResponseInputItem], + new_items: list[RunItem], + ) -> None: + """ + Save the conversation turn to session. + It does not account for any filtering or modification performed by + `RunConfig.session_input_callback`. + """ + if session is None: + return + + # Convert original input to list format if needed + input_list = ItemHelpers.input_to_new_input_list(original_input) + + # Convert new items to input format + new_items_as_input = [item.to_input_item() for item in new_items] + + # Save all items from this turn + items_to_save = input_list + new_items_as_input + await session.add_items(items_to_save) + + @staticmethod + async def _input_guardrail_tripwire_triggered_for_stream( + streamed_result: RunResultStreaming, + ) -> bool: + """Return True if any input guardrail triggered during a streamed run.""" + + task = streamed_result._input_guardrails_task + if task is None: + return False + + if not task.done(): + await task + + return any( + guardrail_result.output.tripwire_triggered + for guardrail_result in streamed_result.input_guardrail_results + ) + + +DEFAULT_AGENT_RUNNER = AgentRunner() +_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) + + +def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: + if isinstance(input, str): + return input + return input.copy() From cfcd77706f5657b9f5854df82f2b0ad1e0edb613 Mon Sep 17 00:00:00 2001 From: shayan-devv Date: Tue, 11 Nov 2025 11:50:53 +0500 Subject: [PATCH 3/4] refactor: move enable_structured_output_with_tools to LitellmModel Moved the enable_structured_output_with_tools parameter from the Agent class to LitellmModel.__init__() to minimize the diff and isolate changes within the LiteLLM adapter as requested during code review. Changes: - Added enable_structured_output_with_tools parameter to LitellmModel.__init__() - Stored as instance variable and used throughout LitellmModel - Removed parameter from Agent class and related validation - Removed parameter from Model interface (get_response / stream_response) - Removed parameter from Runner (no longer passed to model calls) - Removed parameter from OpenAI model implementations - Reverted test mock models to original signatures - Updated test_gemini_local.py for model-level configuration - Updated documentation to reflect model-level usage Before: Agent(model=..., enable_structured_output_with_tools=True) After: Agent(model=LitellmModel(..., enable_structured_output_with_tools=True)) --- docs/agents.md | 8 +++-- docs/models/litellm.md | 8 +++-- docs/models/structured_output_with_tools.md | 16 ++++++---- src/agents/agent.py | 16 ---------- src/agents/extensions/models/litellm_model.py | 15 +++------ src/agents/models/interface.py | 8 ----- src/agents/models/openai_chatcompletions.py | 11 ------- src/agents/models/openai_responses.py | 2 -- src/agents/run.py | 32 ++++--------------- tests/fake_model.py | 2 -- tests/test_agent_prompt.py | 1 - tests/test_gemini_local.py | 15 ++++++--- tests/test_streaming_tool_call_arguments.py | 2 -- tests/voice/test_workflow.py | 2 -- 14 files changed, 42 insertions(+), 96 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 14b5df295..a2e0db1bd 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -81,14 +81,16 @@ from agents.extensions.models.litellm_model import LitellmModel agent = Agent( name="Weather assistant", - model=LitellmModel("gemini/gemini-1.5-flash"), + model=LitellmModel( + "gemini/gemini-2.5-flash", + enable_structured_output_with_tools=True, # Required for Gemini + ), tools=[get_weather], output_type=WeatherReport, - enable_structured_output_with_tools=True, # Required for Gemini ) ``` -The `enable_structured_output_with_tools` parameter injects JSON formatting instructions into the system prompt as a workaround. This is only needed for models accessed via [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] that lack native support. OpenAI models ignore this parameter. +The `enable_structured_output_with_tools` parameter on [`LitellmModel`][agents.extensions.models.litellm_model.LitellmModel] injects JSON formatting instructions into the system prompt as a workaround. This is only needed for models that lack native support for using tools and structured outputs simultaneously (like Gemini). See the [prompt injection documentation](models/structured_output_with_tools.md) for more details. diff --git a/docs/models/litellm.md b/docs/models/litellm.md index 163877925..5c8aecfb4 100644 --- a/docs/models/litellm.md +++ b/docs/models/litellm.md @@ -111,13 +111,15 @@ def analyze_data(query: str) -> dict: agent = Agent( name="Analyst", - model=LitellmModel("gemini/gemini-1.5-flash"), + model=LitellmModel( + "gemini/gemini-2.5-flash", + enable_structured_output_with_tools=True, # Required for Gemini + ), tools=[analyze_data], output_type=Report, - enable_structured_output_with_tools=True, # Required for Gemini ) ``` -The `enable_structured_output_with_tools` parameter enables a workaround that injects JSON formatting instructions into the system prompt instead of using the native API. This allows models like Gemini to return structured outputs even when using tools. +The `enable_structured_output_with_tools` parameter on `LitellmModel` enables a workaround that injects JSON formatting instructions into the system prompt instead of using the native API. This allows models like Gemini to return structured outputs even when using tools. See the [prompt injection documentation](structured_output_with_tools.md) for complete details. diff --git a/docs/models/structured_output_with_tools.md b/docs/models/structured_output_with_tools.md index 7c00c64fb..899a086b3 100644 --- a/docs/models/structured_output_with_tools.md +++ b/docs/models/structured_output_with_tools.md @@ -25,7 +25,7 @@ def get_weather(city: str) -> dict: # This causes an error with Gemini agent = Agent( - model=LitellmModel("gemini/gemini-1.5-flash"), + model=LitellmModel("gemini/gemini-2.5-flash"), tools=[get_weather], output_type=WeatherReport, # Error: can't use both! ) @@ -40,14 +40,16 @@ GeminiException BadRequestError - Function calling with a response mime type ## The Solution -Enable prompt injection by setting `enable_structured_output_with_tools=True` on your agent: +Enable prompt injection by setting `enable_structured_output_with_tools=True` on the `LitellmModel`: ```python agent = Agent( - model=LitellmModel("gemini/gemini-1.5-flash"), + model=LitellmModel( + "gemini/gemini-2.5-flash", + enable_structured_output_with_tools=True, # ← Enables the workaround + ), tools=[get_weather], output_type=WeatherReport, - enable_structured_output_with_tools=True, # ← Enables the workaround ) ``` @@ -90,10 +92,12 @@ async def main(): agent = Agent( name="WeatherBot", instructions="Use the get_weather tool, then provide a structured report.", - model=LitellmModel("gemini/gemini-1.5-flash"), + model=LitellmModel( + "gemini/gemini-2.5-flash", + enable_structured_output_with_tools=True, # Required for Gemini + ), tools=[get_weather], output_type=WeatherReport, - enable_structured_output_with_tools=True, # Required for Gemini ) result = await Runner.run(agent, "What's the weather in Tokyo?") diff --git a/src/agents/agent.py b/src/agents/agent.py index c05a2c02a..a061926b1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -231,16 +231,6 @@ class Agent(AgentBase, Generic[TContext]): """Whether to reset the tool choice to the default value after a tool has been called. Defaults to True. This ensures that the agent doesn't enter an infinite loop of tool usage.""" - enable_structured_output_with_tools: bool = False - """Enable structured outputs when using tools on models that don't natively support both - simultaneously (e.g., Gemini). When enabled, injects JSON formatting instructions into the - system prompt as a workaround instead of using the native API. Defaults to False (use native - API support when available). - - Set to True when using models that don't support both features natively (e.g., Gemini via - LiteLLM). - """ - def __post_init__(self): from typing import get_origin @@ -374,12 +364,6 @@ def __post_init__(self): f"got {type(self.reset_tool_choice).__name__}" ) - if not isinstance(self.enable_structured_output_with_tools, bool): - raise TypeError( - f"Agent enable_structured_output_with_tools must be a boolean, " - f"got {type(self.enable_structured_output_with_tools).__name__}" - ) - def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. Notes: diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index c0968b87f..e0709902c 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -74,10 +74,12 @@ def __init__( model: str, base_url: str | None = None, api_key: str | None = None, + enable_structured_output_with_tools: bool = False, ): self.model = model self.base_url = base_url self.api_key = api_key + self.enable_structured_output_with_tools = enable_structured_output_with_tools async def get_response( self, @@ -89,9 +91,8 @@ async def get_response( handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused + conversation_id: str | None = None, prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: with generation_span( model=str(self.model), @@ -110,7 +111,6 @@ async def get_response( tracing, stream=False, prompt=prompt, - enable_structured_output_with_tools=enable_structured_output_with_tools, ) message: litellm.types.utils.Message | None = None @@ -195,9 +195,8 @@ async def stream_response( handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused + conversation_id: str | None = None, prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: with generation_span( model=str(self.model), @@ -216,7 +215,6 @@ async def stream_response( tracing, stream=True, prompt=prompt, - enable_structured_output_with_tools=enable_structured_output_with_tools, ) final_response: Response | None = None @@ -248,7 +246,6 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[True], prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... @overload @@ -264,7 +261,6 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[False], prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> litellm.types.utils.ModelResponse: ... async def _fetch_response( @@ -279,7 +275,6 @@ async def _fetch_response( tracing: ModelTracing, stream: bool = False, prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]: # Preserve reasoning messages for tool calls when reasoning is on # This is needed for models like Claude 4 Sonnet/Opus which support interleaved thinking @@ -298,7 +293,7 @@ async def _fetch_response( # Check if we need to inject JSON output prompt for models that don't support # tools + structured output simultaneously (like Gemini) inject_json_prompt = should_inject_json_prompt( - output_schema, tools, enable_structured_output_with_tools + output_schema, tools, self.enable_structured_output_with_tools ) if inject_json_prompt and output_schema: json_prompt = get_json_output_prompt(output_schema) diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index f69946f8b..1e90a3a9d 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -50,7 +50,6 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: ResponsePromptParam | None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: """Get a response from the model. @@ -66,9 +65,6 @@ async def get_response( except for the OpenAI Responses API. conversation_id: The ID of the stored conversation, if any. prompt: The prompt config to use for the model. - enable_structured_output_with_tools: Whether to inject JSON formatting instructions - into the system prompt when using structured outputs with tools. Required for - models that don't support both features natively (like Gemini). Returns: The full model response. @@ -89,7 +85,6 @@ def stream_response( previous_response_id: str | None, conversation_id: str | None, prompt: ResponsePromptParam | None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: """Stream a response from the model. @@ -105,9 +100,6 @@ def stream_response( except for the OpenAI Responses API. conversation_id: The ID of the stored conversation, if any. prompt: The prompt config to use for the model. - enable_structured_output_with_tools: Whether to inject JSON formatting instructions - into the system prompt when using structured outputs with tools. Required for - models that don't support both features natively (like Gemini). Returns: An iterator of response stream events, in OpenAI Responses format. diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 56b79ced3..91c1c6174 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -59,7 +59,6 @@ async def get_response( previous_response_id: str | None = None, # unused conversation_id: str | None = None, # unused prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: with generation_span( model=str(self.model), @@ -77,7 +76,6 @@ async def get_response( tracing, stream=False, prompt=prompt, - enable_structured_output_with_tools=enable_structured_output_with_tools, ) message: ChatCompletionMessage | None = None @@ -149,7 +147,6 @@ async def stream_response( previous_response_id: str | None = None, # unused conversation_id: str | None = None, # unused prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: """ Yields a partial message as it is generated, as well as the usage information. @@ -170,7 +167,6 @@ async def stream_response( tracing, stream=True, prompt=prompt, - enable_structured_output_with_tools=enable_structured_output_with_tools, ) final_response: Response | None = None @@ -202,7 +198,6 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[True], prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... @overload @@ -218,7 +213,6 @@ async def _fetch_response( tracing: ModelTracing, stream: Literal[False], prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> ChatCompletion: ... async def _fetch_response( @@ -233,12 +227,7 @@ async def _fetch_response( tracing: ModelTracing, stream: bool = False, prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: - # Note: enable_structured_output_with_tools parameter is accepted for interface consistency - # but not used for OpenAI models since they have native support for - # tools + structured outputs simultaneously - converted_messages = Converter.items_to_messages(input) if system_instructions: diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index ca48d17d8..6ef191914 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -84,7 +84,6 @@ async def get_response( previous_response_id: str | None = None, conversation_id: str | None = None, prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: with response_span(disabled=tracing.is_disabled()) as span_response: try: @@ -162,7 +161,6 @@ async def stream_response( previous_response_id: str | None = None, conversation_id: str | None = None, prompt: ResponsePromptParam | None = None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[ResponseStreamEvent]: """ Yields a partial message as it is generated, as well as the usage information. diff --git a/src/agents/run.py b/src/agents/run.py index 6867daf33..55a0b59e3 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1287,18 +1287,6 @@ async def _run_single_turn_streamed( ) # 1. Stream the output events - # Build kwargs for model call - model_kwargs: dict[str, Any] = { - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - "prompt": prompt_config, - } - - # Only pass enable_structured_output_with_tools when enabled - # to maintain backward compatibility with third-party Model implementations - if agent.enable_structured_output_with_tools: - model_kwargs["enable_structured_output_with_tools"] = True - async for event in model.stream_response( filtered.instructions, filtered.input, @@ -1309,7 +1297,9 @@ async def _run_single_turn_streamed( get_model_tracing_impl( run_config.tracing_disabled, run_config.trace_include_sensitive_data ), - **model_kwargs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, ): # Emit the raw event ASAP streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) @@ -1732,18 +1722,6 @@ async def _get_new_response( server_conversation_tracker.conversation_id if server_conversation_tracker else None ) - # Build kwargs for model call - model_kwargs: dict[str, Any] = { - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - "prompt": prompt_config, - } - - # Only pass enable_structured_output_with_tools when enabled - # to maintain backward compatibility with third-party Model implementations - if agent.enable_structured_output_with_tools: - model_kwargs["enable_structured_output_with_tools"] = True - new_response = await model.get_response( system_instructions=filtered.instructions, input=filtered.input, @@ -1754,7 +1732,9 @@ async def _get_new_response( tracing=get_model_tracing_impl( run_config.tracing_disabled, run_config.trace_include_sensitive_data ), - **model_kwargs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, ) context_wrapper.usage.add(new_response.usage) diff --git a/tests/fake_model.py b/tests/fake_model.py index e6898dbe1..efedeb7fe 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -90,7 +90,6 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: turn_args = { "system_instructions": system_instructions, @@ -141,7 +140,6 @@ async def stream_response( previous_response_id: str | None = None, conversation_id: str | None = None, prompt: Any | None = None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: turn_args = { "system_instructions": system_instructions, diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py index 2e9334861..b11c78893 100644 --- a/tests/test_agent_prompt.py +++ b/tests/test_agent_prompt.py @@ -26,7 +26,6 @@ async def get_response( previous_response_id, conversation_id, prompt, - enable_structured_output_with_tools: bool = False, ): # Record the prompt that the agent resolved and passed in. self.last_prompt = prompt diff --git a/tests/test_gemini_local.py b/tests/test_gemini_local.py index 208364425..6d1d2b309 100644 --- a/tests/test_gemini_local.py +++ b/tests/test_gemini_local.py @@ -65,24 +65,31 @@ async def main(): print("\n🔍 The final system prompt sent to Gemini will be shown below") print("=" * 80) - # Create agent with prompt injection enabled + # Create agent with prompt injection enabled on the model agent = Agent( name="weather_assistant", instructions=( "You are a helpful weather assistant. Use the get_weather tool to " "fetch weather information, then provide a structured report." ), - model=LitellmModel("gemini/gemini-2.5-flash"), + model=LitellmModel( + "gemini/gemini-2.5-flash", + enable_structured_output_with_tools=True, # CRITICAL: Enable for Gemini! + ), tools=[get_weather], output_type=WeatherReport, - enable_structured_output_with_tools=True, # CRITICAL: Enable for Gemini! ) print("\nAgent Configuration:") print(" Model: gemini/gemini-2.5-flash") print(f" Tools: {[tool.name for tool in agent.tools]}") print(" Output Type: WeatherReport") - print(f" enable_structured_output_with_tools: {agent.enable_structured_output_with_tools}") + # Type check: ensure agent.model is LitellmModel + if isinstance(agent.model, LitellmModel): + print( + f" enable_structured_output_with_tools: " + f"{agent.model.enable_structured_output_with_tools}" + ) print(f"\n{'=' * 80}") print("Running agent with input: 'What's the weather in Tokyo?'") diff --git a/tests/test_streaming_tool_call_arguments.py b/tests/test_streaming_tool_call_arguments.py index 041a24713..8e0f847c4 100644 --- a/tests/test_streaming_tool_call_arguments.py +++ b/tests/test_streaming_tool_call_arguments.py @@ -59,7 +59,6 @@ async def get_response( previous_response_id: Optional[str], conversation_id: Optional[str], prompt: Optional[Any], - enable_structured_output_with_tools: bool = False, ): raise NotImplementedError("Use stream_response instead") @@ -76,7 +75,6 @@ async def stream_response( previous_response_id: Optional[str] = None, conversation_id: Optional[str] = None, prompt: Optional[Any] = None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: """Stream events that simulate real OpenAI streaming behavior for tool calls.""" self.last_turn_args = { diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index ca47faf6e..a12be1dd1 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -57,7 +57,6 @@ async def get_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, - enable_structured_output_with_tools: bool = False, ) -> ModelResponse: raise NotImplementedError("Not implemented") @@ -74,7 +73,6 @@ async def stream_response( previous_response_id: str | None, conversation_id: str | None, prompt: Any | None, - enable_structured_output_with_tools: bool = False, ) -> AsyncIterator[TResponseStreamEvent]: output = self.get_next_output() for item in output: From 1d140ffe0f7f3e22387374f63153efbdf9da7153 Mon Sep 17 00:00:00 2001 From: shayan-devv Date: Tue, 11 Nov 2025 15:25:17 +0500 Subject: [PATCH 4/4] fix: include handoffs when deciding to inject JSON prompt The JSON prompt injection was only triggered when tools list was non-empty, but handoffs are converted to function tools and added separately. This meant that agents using only handoffs with output_schema would not get the prompt injection even when enable_structured_output_with_tools=True, causing Gemini to error with 'Function calling with response mime type application/json is unsupported.' Changes: - Combine tools and handoffs before checking if JSON prompt should be injected - Add test case for handoffs-only scenario - Update inline comment to clarify why handoffs must be included This ensures the opt-in flag works correctly for multi-agent scenarios where an agent might use handoffs without regular tools. --- src/agents/extensions/models/litellm_model.py | 6 +- src/agents/models/interface.py | 250 +- src/agents/models/openai_chatcompletions.py | 718 ++-- src/agents/models/openai_responses.py | 1032 ++--- src/agents/run.py | 3790 ++++++++--------- tests/fake_model.py | 686 +-- tests/test_agent_prompt.py | 198 +- tests/test_streaming_tool_call_arguments.py | 746 ++-- tests/utils/test_prompts.py | 11 + tests/voice/test_workflow.py | 438 +- 10 files changed, 3945 insertions(+), 3930 deletions(-) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index e0709902c..f3d13aa24 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -292,8 +292,12 @@ async def _fetch_response( # Check if we need to inject JSON output prompt for models that don't support # tools + structured output simultaneously (like Gemini) + # Note: handoffs are converted to function tools, so we need to include them in the check + tools_and_handoffs = list(tools) if tools else [] + if handoffs: + tools_and_handoffs.extend(handoffs) inject_json_prompt = should_inject_json_prompt( - output_schema, tools, self.enable_structured_output_with_tools + output_schema, tools_and_handoffs, self.enable_structured_output_with_tools ) if inject_json_prompt and output_schema: json_prompt = get_json_output_prompt(output_schema) diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index 1e90a3a9d..f25934780 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -1,125 +1,125 @@ -from __future__ import annotations - -import abc -import enum -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING - -from openai.types.responses.response_prompt_param import ResponsePromptParam - -from ..agent_output import AgentOutputSchemaBase -from ..handoffs import Handoff -from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent -from ..tool import Tool - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -class ModelTracing(enum.Enum): - DISABLED = 0 - """Tracing is disabled entirely.""" - - ENABLED = 1 - """Tracing is enabled, and all data is included.""" - - ENABLED_WITHOUT_DATA = 2 - """Tracing is enabled, but inputs/outputs are not included.""" - - def is_disabled(self) -> bool: - return self == ModelTracing.DISABLED - - def include_data(self) -> bool: - return self == ModelTracing.ENABLED - - -class Model(abc.ABC): - """The base interface for calling an LLM.""" - - @abc.abstractmethod - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: ResponsePromptParam | None, - ) -> ModelResponse: - """Get a response from the model. - - Args: - system_instructions: The system instructions to use. - input: The input items to the model, in OpenAI Responses format. - model_settings: The model settings to use. - tools: The tools available to the model. - output_schema: The output schema to use. - handoffs: The handoffs available to the model. - tracing: Tracing configuration. - previous_response_id: the ID of the previous response. Generally not used by the model, - except for the OpenAI Responses API. - conversation_id: The ID of the stored conversation, if any. - prompt: The prompt config to use for the model. - - Returns: - The full model response. - """ - pass - - @abc.abstractmethod - def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: ResponsePromptParam | None, - ) -> AsyncIterator[TResponseStreamEvent]: - """Stream a response from the model. - - Args: - system_instructions: The system instructions to use. - input: The input items to the model, in OpenAI Responses format. - model_settings: The model settings to use. - tools: The tools available to the model. - output_schema: The output schema to use. - handoffs: The handoffs available to the model. - tracing: Tracing configuration. - previous_response_id: the ID of the previous response. Generally not used by the model, - except for the OpenAI Responses API. - conversation_id: The ID of the stored conversation, if any. - prompt: The prompt config to use for the model. - - Returns: - An iterator of response stream events, in OpenAI Responses format. - """ - pass - - -class ModelProvider(abc.ABC): - """The base interface for a model provider. - - Model provider is responsible for looking up Models by name. - """ - - @abc.abstractmethod - def get_model(self, model_name: str | None) -> Model: - """Get a model by name. - - Args: - model_name: The name of the model to get. - - Returns: - The model. - """ +from __future__ import annotations + +import abc +import enum +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +from openai.types.responses.response_prompt_param import ResponsePromptParam + +from ..agent_output import AgentOutputSchemaBase +from ..handoffs import Handoff +from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent +from ..tool import Tool + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +class ModelTracing(enum.Enum): + DISABLED = 0 + """Tracing is disabled entirely.""" + + ENABLED = 1 + """Tracing is enabled, and all data is included.""" + + ENABLED_WITHOUT_DATA = 2 + """Tracing is enabled, but inputs/outputs are not included.""" + + def is_disabled(self) -> bool: + return self == ModelTracing.DISABLED + + def include_data(self) -> bool: + return self == ModelTracing.ENABLED + + +class Model(abc.ABC): + """The base interface for calling an LLM.""" + + @abc.abstractmethod + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> ModelResponse: + """Get a response from the model. + + Args: + system_instructions: The system instructions to use. + input: The input items to the model, in OpenAI Responses format. + model_settings: The model settings to use. + tools: The tools available to the model. + output_schema: The output schema to use. + handoffs: The handoffs available to the model. + tracing: Tracing configuration. + previous_response_id: the ID of the previous response. Generally not used by the model, + except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. + prompt: The prompt config to use for the model. + + Returns: + The full model response. + """ + pass + + @abc.abstractmethod + def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[TResponseStreamEvent]: + """Stream a response from the model. + + Args: + system_instructions: The system instructions to use. + input: The input items to the model, in OpenAI Responses format. + model_settings: The model settings to use. + tools: The tools available to the model. + output_schema: The output schema to use. + handoffs: The handoffs available to the model. + tracing: Tracing configuration. + previous_response_id: the ID of the previous response. Generally not used by the model, + except for the OpenAI Responses API. + conversation_id: The ID of the stored conversation, if any. + prompt: The prompt config to use for the model. + + Returns: + An iterator of response stream events, in OpenAI Responses format. + """ + pass + + +class ModelProvider(abc.ABC): + """The base interface for a model provider. + + Model provider is responsible for looking up Models by name. + """ + + @abc.abstractmethod + def get_model(self, model_name: str | None) -> Model: + """Get a model by name. + + Args: + model_name: The name of the model to get. + + Returns: + The model. + """ diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 91c1c6174..d6cf662d2 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -1,359 +1,359 @@ -from __future__ import annotations - -import json -import time -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any, Literal, cast, overload - -from openai import AsyncOpenAI, AsyncStream, Omit, omit -from openai.types import ChatModel -from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice -from openai.types.responses import Response -from openai.types.responses.response_prompt_param import ResponsePromptParam -from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails - -from .. import _debug -from ..agent_output import AgentOutputSchemaBase -from ..handoffs import Handoff -from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent -from ..logger import logger -from ..tool import Tool -from ..tracing import generation_span -from ..tracing.span_data import GenerationSpanData -from ..tracing.spans import Span -from ..usage import Usage -from ..util._json import _to_dump_compatible -from .chatcmpl_converter import Converter -from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers -from .chatcmpl_stream_handler import ChatCmplStreamHandler -from .fake_id import FAKE_RESPONSES_ID -from .interface import Model, ModelTracing -from .openai_responses import Converter as OpenAIResponsesConverter - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -class OpenAIChatCompletionsModel(Model): - def __init__( - self, - model: str | ChatModel, - openai_client: AsyncOpenAI, - ) -> None: - self.model = model - self._client = openai_client - - def _non_null_or_omit(self, value: Any) -> Any: - return value if value is not None else omit - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused - prompt: ResponsePromptParam | None = None, - ) -> ModelResponse: - with generation_span( - model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, - disabled=tracing.is_disabled(), - ) as span_generation: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=False, - prompt=prompt, - ) - - message: ChatCompletionMessage | None = None - first_choice: Choice | None = None - if response.choices and len(response.choices) > 0: - first_choice = response.choices[0] - message = first_choice.message - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Received model response") - else: - if message is not None: - logger.debug( - "LLM resp:\n%s\n", - json.dumps(message.model_dump(), indent=2, ensure_ascii=False), - ) - else: - finish_reason = first_choice.finish_reason if first_choice else "-" - logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") - - usage = ( - Usage( - requests=1, - input_tokens=response.usage.prompt_tokens, - output_tokens=response.usage.completion_tokens, - total_tokens=response.usage.total_tokens, - input_tokens_details=InputTokensDetails( - cached_tokens=getattr( - response.usage.prompt_tokens_details, "cached_tokens", 0 - ) - or 0, - ), - output_tokens_details=OutputTokensDetails( - reasoning_tokens=getattr( - response.usage.completion_tokens_details, "reasoning_tokens", 0 - ) - or 0, - ), - ) - if response.usage - else Usage() - ) - if tracing.include_data(): - span_generation.span_data.output = ( - [message.model_dump()] if message is not None else [] - ) - span_generation.span_data.usage = { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - } - - items = Converter.message_to_output_items(message) if message is not None else [] - - return ModelResponse( - output=items, - usage=usage, - response_id=None, - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, # unused - conversation_id: str | None = None, # unused - prompt: ResponsePromptParam | None = None, - ) -> AsyncIterator[TResponseStreamEvent]: - """ - Yields a partial message as it is generated, as well as the usage information. - """ - with generation_span( - model=str(self.model), - model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, - disabled=tracing.is_disabled(), - ) as span_generation: - response, stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=True, - prompt=prompt, - ) - - final_response: Response | None = None - async for chunk in ChatCmplStreamHandler.handle_stream(response, stream): - yield chunk - - if chunk.type == "response.completed": - final_response = chunk.response - - if tracing.include_data() and final_response: - span_generation.span_data.output = [final_response.model_dump()] - - if final_response and final_response.usage: - span_generation.span_data.usage = { - "input_tokens": final_response.usage.input_tokens, - "output_tokens": final_response.usage.output_tokens, - } - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: Literal[True], - prompt: ResponsePromptParam | None = None, - ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: Literal[False], - prompt: ResponsePromptParam | None = None, - ) -> ChatCompletion: ... - - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - span: Span[GenerationSpanData], - tracing: ModelTracing, - stream: bool = False, - prompt: ResponsePromptParam | None = None, - ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: - converted_messages = Converter.items_to_messages(input) - - if system_instructions: - converted_messages.insert( - 0, - { - "content": system_instructions, - "role": "system", - }, - ) - converted_messages = _to_dump_compatible(converted_messages) - - if tracing.include_data(): - span.span_data.input = converted_messages - - if model_settings.parallel_tool_calls and tools: - parallel_tool_calls: bool | Omit = True - elif model_settings.parallel_tool_calls is False: - parallel_tool_calls = False - else: - parallel_tool_calls = omit - tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) - response_format = Converter.convert_response_format(output_schema) - - converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] - - for handoff in handoffs: - converted_tools.append(Converter.convert_handoff_tool(handoff)) - - converted_tools = _to_dump_compatible(converted_tools) - tools_param = converted_tools if converted_tools else omit - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Calling LLM") - else: - messages_json = json.dumps( - converted_messages, - indent=2, - ensure_ascii=False, - ) - tools_json = json.dumps( - converted_tools, - indent=2, - ensure_ascii=False, - ) - logger.debug( - f"{messages_json}\n" - f"Tools:\n{tools_json}\n" - f"Stream: {stream}\n" - f"Tool choice: {tool_choice}\n" - f"Response format: {response_format}\n" - ) - - reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None - store = ChatCmplHelpers.get_store_param(self._get_client(), model_settings) - - stream_options = ChatCmplHelpers.get_stream_options_param( - self._get_client(), model_settings, stream=stream - ) - - stream_param: Literal[True] | Omit = True if stream else omit - - ret = await self._get_client().chat.completions.create( - model=self.model, - messages=converted_messages, - tools=tools_param, - temperature=self._non_null_or_omit(model_settings.temperature), - top_p=self._non_null_or_omit(model_settings.top_p), - frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty), - presence_penalty=self._non_null_or_omit(model_settings.presence_penalty), - max_tokens=self._non_null_or_omit(model_settings.max_tokens), - tool_choice=tool_choice, - response_format=response_format, - parallel_tool_calls=parallel_tool_calls, - stream=cast(Any, stream_param), - stream_options=self._non_null_or_omit(stream_options), - store=self._non_null_or_omit(store), - reasoning_effort=self._non_null_or_omit(reasoning_effort), - verbosity=self._non_null_or_omit(model_settings.verbosity), - top_logprobs=self._non_null_or_omit(model_settings.top_logprobs), - extra_headers=self._merge_headers(model_settings), - extra_query=model_settings.extra_query, - extra_body=model_settings.extra_body, - metadata=self._non_null_or_omit(model_settings.metadata), - **(model_settings.extra_args or {}), - ) - - if isinstance(ret, ChatCompletion): - return ret - - responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( - model_settings.tool_choice - ) - if responses_tool_choice is None or responses_tool_choice is omit: - # For Responses API data compatibility with Chat Completions patterns, - # we need to set "none" if tool_choice is absent. - # Without this fix, you'll get the following error: - # pydantic_core._pydantic_core.ValidationError: 4 validation errors for Response - # tool_choice.literal['none','auto','required'] - # Input should be 'none', 'auto' or 'required' - # see also: https://github.com/openai/openai-agents-python/issues/980 - responses_tool_choice = "auto" - - response = Response( - id=FAKE_RESPONSES_ID, - created_at=time.time(), - model=self.model, - object="response", - output=[], - tool_choice=responses_tool_choice, # type: ignore[arg-type] - top_p=model_settings.top_p, - temperature=model_settings.temperature, - tools=[], - parallel_tool_calls=parallel_tool_calls or False, - reasoning=model_settings.reasoning, - ) - return response, ret - - def _get_client(self) -> AsyncOpenAI: - if self._client is None: - self._client = AsyncOpenAI() - return self._client - - def _merge_headers(self, model_settings: ModelSettings): - return { - **HEADERS, - **(model_settings.extra_headers or {}), - **(HEADERS_OVERRIDE.get() or {}), - } +from __future__ import annotations + +import json +import time +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, Literal, cast, overload + +from openai import AsyncOpenAI, AsyncStream, Omit, omit +from openai.types import ChatModel +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.responses import Response +from openai.types.responses.response_prompt_param import ResponsePromptParam +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from .. import _debug +from ..agent_output import AgentOutputSchemaBase +from ..handoffs import Handoff +from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent +from ..logger import logger +from ..tool import Tool +from ..tracing import generation_span +from ..tracing.span_data import GenerationSpanData +from ..tracing.spans import Span +from ..usage import Usage +from ..util._json import _to_dump_compatible +from .chatcmpl_converter import Converter +from .chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE, ChatCmplHelpers +from .chatcmpl_stream_handler import ChatCmplStreamHandler +from .fake_id import FAKE_RESPONSES_ID +from .interface import Model, ModelTracing +from .openai_responses import Converter as OpenAIResponsesConverter + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +class OpenAIChatCompletionsModel(Model): + def __init__( + self, + model: str | ChatModel, + openai_client: AsyncOpenAI, + ) -> None: + self.model = model + self._client = openai_client + + def _non_null_or_omit(self, value: Any) -> Any: + return value if value is not None else omit + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused + prompt: ResponsePromptParam | None = None, + ) -> ModelResponse: + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + disabled=tracing.is_disabled(), + ) as span_generation: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=False, + prompt=prompt, + ) + + message: ChatCompletionMessage | None = None + first_choice: Choice | None = None + if response.choices and len(response.choices) > 0: + first_choice = response.choices[0] + message = first_choice.message + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Received model response") + else: + if message is not None: + logger.debug( + "LLM resp:\n%s\n", + json.dumps(message.model_dump(), indent=2, ensure_ascii=False), + ) + else: + finish_reason = first_choice.finish_reason if first_choice else "-" + logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}") + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + response.usage.prompt_tokens_details, "cached_tokens", 0 + ) + or 0, + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + response.usage.completion_tokens_details, "reasoning_tokens", 0 + ) + or 0, + ), + ) + if response.usage + else Usage() + ) + if tracing.include_data(): + span_generation.span_data.output = ( + [message.model_dump()] if message is not None else [] + ) + span_generation.span_data.usage = { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + } + + items = Converter.message_to_output_items(message) if message is not None else [] + + return ModelResponse( + output=items, + usage=usage, + response_id=None, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, # unused + conversation_id: str | None = None, # unused + prompt: ResponsePromptParam | None = None, + ) -> AsyncIterator[TResponseStreamEvent]: + """ + Yields a partial message as it is generated, as well as the usage information. + """ + with generation_span( + model=str(self.model), + model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)}, + disabled=tracing.is_disabled(), + ) as span_generation: + response, stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=True, + prompt=prompt, + ) + + final_response: Response | None = None + async for chunk in ChatCmplStreamHandler.handle_stream(response, stream): + yield chunk + + if chunk.type == "response.completed": + final_response = chunk.response + + if tracing.include_data() and final_response: + span_generation.span_data.output = [final_response.model_dump()] + + if final_response and final_response.usage: + span_generation.span_data.usage = { + "input_tokens": final_response.usage.input_tokens, + "output_tokens": final_response.usage.output_tokens, + } + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[True], + prompt: ResponsePromptParam | None = None, + ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: Literal[False], + prompt: ResponsePromptParam | None = None, + ) -> ChatCompletion: ... + + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + span: Span[GenerationSpanData], + tracing: ModelTracing, + stream: bool = False, + prompt: ResponsePromptParam | None = None, + ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: + converted_messages = Converter.items_to_messages(input) + + if system_instructions: + converted_messages.insert( + 0, + { + "content": system_instructions, + "role": "system", + }, + ) + converted_messages = _to_dump_compatible(converted_messages) + + if tracing.include_data(): + span.span_data.input = converted_messages + + if model_settings.parallel_tool_calls and tools: + parallel_tool_calls: bool | Omit = True + elif model_settings.parallel_tool_calls is False: + parallel_tool_calls = False + else: + parallel_tool_calls = omit + tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) + response_format = Converter.convert_response_format(output_schema) + + converted_tools = [Converter.tool_to_openai(tool) for tool in tools] if tools else [] + + for handoff in handoffs: + converted_tools.append(Converter.convert_handoff_tool(handoff)) + + converted_tools = _to_dump_compatible(converted_tools) + tools_param = converted_tools if converted_tools else omit + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Calling LLM") + else: + messages_json = json.dumps( + converted_messages, + indent=2, + ensure_ascii=False, + ) + tools_json = json.dumps( + converted_tools, + indent=2, + ensure_ascii=False, + ) + logger.debug( + f"{messages_json}\n" + f"Tools:\n{tools_json}\n" + f"Stream: {stream}\n" + f"Tool choice: {tool_choice}\n" + f"Response format: {response_format}\n" + ) + + reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None + store = ChatCmplHelpers.get_store_param(self._get_client(), model_settings) + + stream_options = ChatCmplHelpers.get_stream_options_param( + self._get_client(), model_settings, stream=stream + ) + + stream_param: Literal[True] | Omit = True if stream else omit + + ret = await self._get_client().chat.completions.create( + model=self.model, + messages=converted_messages, + tools=tools_param, + temperature=self._non_null_or_omit(model_settings.temperature), + top_p=self._non_null_or_omit(model_settings.top_p), + frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty), + presence_penalty=self._non_null_or_omit(model_settings.presence_penalty), + max_tokens=self._non_null_or_omit(model_settings.max_tokens), + tool_choice=tool_choice, + response_format=response_format, + parallel_tool_calls=parallel_tool_calls, + stream=cast(Any, stream_param), + stream_options=self._non_null_or_omit(stream_options), + store=self._non_null_or_omit(store), + reasoning_effort=self._non_null_or_omit(reasoning_effort), + verbosity=self._non_null_or_omit(model_settings.verbosity), + top_logprobs=self._non_null_or_omit(model_settings.top_logprobs), + extra_headers=self._merge_headers(model_settings), + extra_query=model_settings.extra_query, + extra_body=model_settings.extra_body, + metadata=self._non_null_or_omit(model_settings.metadata), + **(model_settings.extra_args or {}), + ) + + if isinstance(ret, ChatCompletion): + return ret + + responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice( + model_settings.tool_choice + ) + if responses_tool_choice is None or responses_tool_choice is omit: + # For Responses API data compatibility with Chat Completions patterns, + # we need to set "none" if tool_choice is absent. + # Without this fix, you'll get the following error: + # pydantic_core._pydantic_core.ValidationError: 4 validation errors for Response + # tool_choice.literal['none','auto','required'] + # Input should be 'none', 'auto' or 'required' + # see also: https://github.com/openai/openai-agents-python/issues/980 + responses_tool_choice = "auto" + + response = Response( + id=FAKE_RESPONSES_ID, + created_at=time.time(), + model=self.model, + object="response", + output=[], + tool_choice=responses_tool_choice, # type: ignore[arg-type] + top_p=model_settings.top_p, + temperature=model_settings.temperature, + tools=[], + parallel_tool_calls=parallel_tool_calls or False, + reasoning=model_settings.reasoning, + ) + return response, ret + + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = AsyncOpenAI() + return self._client + + def _merge_headers(self, model_settings: ModelSettings): + return { + **HEADERS, + **(model_settings.extra_headers or {}), + **(HEADERS_OVERRIDE.get() or {}), + } diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 6ef191914..36a981404 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -1,516 +1,516 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncIterator -from contextvars import ContextVar -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload - -from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit -from openai.types import ChatModel -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseIncludable, - ResponseStreamEvent, - ResponseTextConfigParam, - ToolParam, - response_create_params, -) -from openai.types.responses.response_prompt_param import ResponsePromptParam - -from .. import _debug -from ..agent_output import AgentOutputSchemaBase -from ..exceptions import UserError -from ..handoffs import Handoff -from ..items import ItemHelpers, ModelResponse, TResponseInputItem -from ..logger import logger -from ..model_settings import MCPToolChoice -from ..tool import ( - CodeInterpreterTool, - ComputerTool, - FileSearchTool, - FunctionTool, - HostedMCPTool, - ImageGenerationTool, - LocalShellTool, - Tool, - WebSearchTool, -) -from ..tracing import SpanError, response_span -from ..usage import Usage -from ..util._json import _to_dump_compatible -from ..version import __version__ -from .interface import Model, ModelTracing - -if TYPE_CHECKING: - from ..model_settings import ModelSettings - - -_USER_AGENT = f"Agents/Python {__version__}" -_HEADERS = {"User-Agent": _USER_AGENT} - -# Override headers used by the Responses API. -_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( - "openai_responses_headers_override", default=None -) - - -class OpenAIResponsesModel(Model): - """ - Implementation of `Model` that uses the OpenAI Responses API. - """ - - def __init__( - self, - model: str | ChatModel, - openai_client: AsyncOpenAI, - ) -> None: - self.model = model - self._client = openai_client - - def _non_null_or_omit(self, value: Any) -> Any: - return value if value is not None else omit - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: ResponsePromptParam | None = None, - ) -> ModelResponse: - with response_span(disabled=tracing.is_disabled()) as span_response: - try: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - stream=False, - prompt=prompt, - ) - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("LLM responded") - else: - logger.debug( - "LLM resp:\n" - f"""{ - json.dumps( - [x.model_dump() for x in response.output], - indent=2, - ensure_ascii=False, - ) - }\n""" - ) - - usage = ( - Usage( - requests=1, - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens, - total_tokens=response.usage.total_tokens, - input_tokens_details=response.usage.input_tokens_details, - output_tokens_details=response.usage.output_tokens_details, - ) - if response.usage - else Usage() - ) - - if tracing.include_data(): - span_response.span_data.response = response - span_response.span_data.input = input - except Exception as e: - span_response.set_error( - SpanError( - message="Error getting response", - data={ - "error": str(e) if tracing.include_data() else e.__class__.__name__, - }, - ) - ) - request_id = e.request_id if isinstance(e, APIStatusError) else None - logger.error(f"Error getting response: {e}. (request_id: {request_id})") - raise - - return ModelResponse( - output=response.output, - usage=usage, - response_id=response.id, - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: ResponsePromptParam | None = None, - ) -> AsyncIterator[ResponseStreamEvent]: - """ - Yields a partial message as it is generated, as well as the usage information. - """ - with response_span(disabled=tracing.is_disabled()) as span_response: - try: - stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - stream=True, - prompt=prompt, - ) - - final_response: Response | None = None - - async for chunk in stream: - if isinstance(chunk, ResponseCompletedEvent): - final_response = chunk.response - yield chunk - - if final_response and tracing.include_data(): - span_response.span_data.response = final_response - span_response.span_data.input = input - - except Exception as e: - span_response.set_error( - SpanError( - message="Error streaming response", - data={ - "error": str(e) if tracing.include_data() else e.__class__.__name__, - }, - ) - ) - logger.error(f"Error streaming response: {e}") - raise - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None, - conversation_id: str | None, - stream: Literal[True], - prompt: ResponsePromptParam | None = None, - ) -> AsyncStream[ResponseStreamEvent]: ... - - @overload - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None, - conversation_id: str | None, - stream: Literal[False], - prompt: ResponsePromptParam | None = None, - ) -> Response: ... - - async def _fetch_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - previous_response_id: str | None = None, - conversation_id: str | None = None, - stream: Literal[True] | Literal[False] = False, - prompt: ResponsePromptParam | None = None, - ) -> Response | AsyncStream[ResponseStreamEvent]: - list_input = ItemHelpers.input_to_new_input_list(input) - list_input = _to_dump_compatible(list_input) - - if model_settings.parallel_tool_calls and tools: - parallel_tool_calls: bool | Omit = True - elif model_settings.parallel_tool_calls is False: - parallel_tool_calls = False - else: - parallel_tool_calls = omit - - tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) - converted_tools = Converter.convert_tools(tools, handoffs) - converted_tools_payload = _to_dump_compatible(converted_tools.tools) - response_format = Converter.get_response_format(output_schema) - - include_set: set[str] = set(converted_tools.includes) - if model_settings.response_include is not None: - include_set.update(model_settings.response_include) - if model_settings.top_logprobs is not None: - include_set.add("message.output_text.logprobs") - include = cast(list[ResponseIncludable], list(include_set)) - - if _debug.DONT_LOG_MODEL_DATA: - logger.debug("Calling LLM") - else: - input_json = json.dumps( - list_input, - indent=2, - ensure_ascii=False, - ) - tools_json = json.dumps( - converted_tools_payload, - indent=2, - ensure_ascii=False, - ) - logger.debug( - f"Calling LLM {self.model} with input:\n" - f"{input_json}\n" - f"Tools:\n{tools_json}\n" - f"Stream: {stream}\n" - f"Tool choice: {tool_choice}\n" - f"Response format: {response_format}\n" - f"Previous response id: {previous_response_id}\n" - f"Conversation id: {conversation_id}\n" - ) - - extra_args = dict(model_settings.extra_args or {}) - if model_settings.top_logprobs is not None: - extra_args["top_logprobs"] = model_settings.top_logprobs - if model_settings.verbosity is not None: - if response_format is not omit: - response_format["verbosity"] = model_settings.verbosity # type: ignore [index] - else: - response_format = {"verbosity": model_settings.verbosity} - - stream_param: Literal[True] | Omit = True if stream else omit - - response = await self._client.responses.create( - previous_response_id=self._non_null_or_omit(previous_response_id), - conversation=self._non_null_or_omit(conversation_id), - instructions=self._non_null_or_omit(system_instructions), - model=self.model, - input=list_input, - include=include, - tools=converted_tools_payload, - prompt=self._non_null_or_omit(prompt), - temperature=self._non_null_or_omit(model_settings.temperature), - top_p=self._non_null_or_omit(model_settings.top_p), - truncation=self._non_null_or_omit(model_settings.truncation), - max_output_tokens=self._non_null_or_omit(model_settings.max_tokens), - tool_choice=tool_choice, - parallel_tool_calls=parallel_tool_calls, - stream=cast(Any, stream_param), - extra_headers=self._merge_headers(model_settings), - extra_query=model_settings.extra_query, - extra_body=model_settings.extra_body, - text=response_format, - store=self._non_null_or_omit(model_settings.store), - reasoning=self._non_null_or_omit(model_settings.reasoning), - metadata=self._non_null_or_omit(model_settings.metadata), - **extra_args, - ) - return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response) - - def _get_client(self) -> AsyncOpenAI: - if self._client is None: - self._client = AsyncOpenAI() - return self._client - - def _merge_headers(self, model_settings: ModelSettings): - return { - **_HEADERS, - **(model_settings.extra_headers or {}), - **(_HEADERS_OVERRIDE.get() or {}), - } - - -@dataclass -class ConvertedTools: - tools: list[ToolParam] - includes: list[ResponseIncludable] - - -class Converter: - @classmethod - def convert_tool_choice( - cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None - ) -> response_create_params.ToolChoice | Omit: - if tool_choice is None: - return omit - elif isinstance(tool_choice, MCPToolChoice): - return { - "server_label": tool_choice.server_label, - "type": "mcp", - "name": tool_choice.name, - } - elif tool_choice == "required": - return "required" - elif tool_choice == "auto": - return "auto" - elif tool_choice == "none": - return "none" - elif tool_choice == "file_search": - return { - "type": "file_search", - } - elif tool_choice == "web_search": - return { - # TODO: revist the type: ignore comment when ToolChoice is updated in the future - "type": "web_search", # type: ignore [typeddict-item] - } - elif tool_choice == "web_search_preview": - return { - "type": "web_search_preview", - } - elif tool_choice == "computer_use_preview": - return { - "type": "computer_use_preview", - } - elif tool_choice == "image_generation": - return { - "type": "image_generation", - } - elif tool_choice == "code_interpreter": - return { - "type": "code_interpreter", - } - elif tool_choice == "mcp": - # Note that this is still here for backwards compatibility, - # but migrating to MCPToolChoice is recommended. - return {"type": "mcp"} # type: ignore [typeddict-item] - else: - return { - "type": "function", - "name": tool_choice, - } - - @classmethod - def get_response_format( - cls, output_schema: AgentOutputSchemaBase | None - ) -> ResponseTextConfigParam | Omit: - if output_schema is None or output_schema.is_plain_text(): - return omit - else: - return { - "format": { - "type": "json_schema", - "name": "final_output", - "schema": output_schema.json_schema(), - "strict": output_schema.is_strict_json_schema(), - } - } - - @classmethod - def convert_tools( - cls, - tools: list[Tool], - handoffs: list[Handoff[Any, Any]], - ) -> ConvertedTools: - converted_tools: list[ToolParam] = [] - includes: list[ResponseIncludable] = [] - - computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] - if len(computer_tools) > 1: - raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}") - - for tool in tools: - converted_tool, include = cls._convert_tool(tool) - converted_tools.append(converted_tool) - if include: - includes.append(include) - - for handoff in handoffs: - converted_tools.append(cls._convert_handoff_tool(handoff)) - - return ConvertedTools(tools=converted_tools, includes=includes) - - @classmethod - def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: - """Returns converted tool and includes""" - - if isinstance(tool, FunctionTool): - converted_tool: ToolParam = { - "name": tool.name, - "parameters": tool.params_json_schema, - "strict": tool.strict_json_schema, - "type": "function", - "description": tool.description, - } - includes: ResponseIncludable | None = None - elif isinstance(tool, WebSearchTool): - # TODO: revist the type: ignore comment when ToolParam is updated in the future - converted_tool = { - "type": "web_search", - "filters": tool.filters.model_dump() if tool.filters is not None else None, # type: ignore [typeddict-item] - "user_location": tool.user_location, - "search_context_size": tool.search_context_size, - } - includes = None - elif isinstance(tool, FileSearchTool): - converted_tool = { - "type": "file_search", - "vector_store_ids": tool.vector_store_ids, - } - if tool.max_num_results: - converted_tool["max_num_results"] = tool.max_num_results - if tool.ranking_options: - converted_tool["ranking_options"] = tool.ranking_options - if tool.filters: - converted_tool["filters"] = tool.filters - - includes = "file_search_call.results" if tool.include_search_results else None - elif isinstance(tool, ComputerTool): - converted_tool = { - "type": "computer_use_preview", - "environment": tool.computer.environment, - "display_width": tool.computer.dimensions[0], - "display_height": tool.computer.dimensions[1], - } - includes = None - elif isinstance(tool, HostedMCPTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, ImageGenerationTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, CodeInterpreterTool): - converted_tool = tool.tool_config - includes = None - elif isinstance(tool, LocalShellTool): - converted_tool = { - "type": "local_shell", - } - includes = None - else: - raise UserError(f"Unknown tool type: {type(tool)}, tool") - - return converted_tool, includes - - @classmethod - def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam: - return { - "name": handoff.tool_name, - "parameters": handoff.input_json_schema, - "strict": handoff.strict_json_schema, - "type": "function", - "description": handoff.tool_description, - } +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from contextvars import ContextVar +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Union, cast, overload + +from openai import APIStatusError, AsyncOpenAI, AsyncStream, Omit, omit +from openai.types import ChatModel +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseIncludable, + ResponseStreamEvent, + ResponseTextConfigParam, + ToolParam, + response_create_params, +) +from openai.types.responses.response_prompt_param import ResponsePromptParam + +from .. import _debug +from ..agent_output import AgentOutputSchemaBase +from ..exceptions import UserError +from ..handoffs import Handoff +from ..items import ItemHelpers, ModelResponse, TResponseInputItem +from ..logger import logger +from ..model_settings import MCPToolChoice +from ..tool import ( + CodeInterpreterTool, + ComputerTool, + FileSearchTool, + FunctionTool, + HostedMCPTool, + ImageGenerationTool, + LocalShellTool, + Tool, + WebSearchTool, +) +from ..tracing import SpanError, response_span +from ..usage import Usage +from ..util._json import _to_dump_compatible +from ..version import __version__ +from .interface import Model, ModelTracing + +if TYPE_CHECKING: + from ..model_settings import ModelSettings + + +_USER_AGENT = f"Agents/Python {__version__}" +_HEADERS = {"User-Agent": _USER_AGENT} + +# Override headers used by the Responses API. +_HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( + "openai_responses_headers_override", default=None +) + + +class OpenAIResponsesModel(Model): + """ + Implementation of `Model` that uses the OpenAI Responses API. + """ + + def __init__( + self, + model: str | ChatModel, + openai_client: AsyncOpenAI, + ) -> None: + self.model = model + self._client = openai_client + + def _non_null_or_omit(self, value: Any) -> Any: + return value if value is not None else omit + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + ) -> ModelResponse: + with response_span(disabled=tracing.is_disabled()) as span_response: + try: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=False, + prompt=prompt, + ) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("LLM responded") + else: + logger.debug( + "LLM resp:\n" + f"""{ + json.dumps( + [x.model_dump() for x in response.output], + indent=2, + ensure_ascii=False, + ) + }\n""" + ) + + usage = ( + Usage( + requests=1, + input_tokens=response.usage.input_tokens, + output_tokens=response.usage.output_tokens, + total_tokens=response.usage.total_tokens, + input_tokens_details=response.usage.input_tokens_details, + output_tokens_details=response.usage.output_tokens_details, + ) + if response.usage + else Usage() + ) + + if tracing.include_data(): + span_response.span_data.response = response + span_response.span_data.input = input + except Exception as e: + span_response.set_error( + SpanError( + message="Error getting response", + data={ + "error": str(e) if tracing.include_data() else e.__class__.__name__, + }, + ) + ) + request_id = e.request_id if isinstance(e, APIStatusError) else None + logger.error(f"Error getting response: {e}. (request_id: {request_id})") + raise + + return ModelResponse( + output=response.output, + usage=usage, + response_id=response.id, + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: ResponsePromptParam | None = None, + ) -> AsyncIterator[ResponseStreamEvent]: + """ + Yields a partial message as it is generated, as well as the usage information. + """ + with response_span(disabled=tracing.is_disabled()) as span_response: + try: + stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + stream=True, + prompt=prompt, + ) + + final_response: Response | None = None + + async for chunk in stream: + if isinstance(chunk, ResponseCompletedEvent): + final_response = chunk.response + yield chunk + + if final_response and tracing.include_data(): + span_response.span_data.response = final_response + span_response.span_data.input = input + + except Exception as e: + span_response.set_error( + SpanError( + message="Error streaming response", + data={ + "error": str(e) if tracing.include_data() else e.__class__.__name__, + }, + ) + ) + logger.error(f"Error streaming response: {e}") + raise + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[True], + prompt: ResponsePromptParam | None = None, + ) -> AsyncStream[ResponseStreamEvent]: ... + + @overload + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None, + conversation_id: str | None, + stream: Literal[False], + prompt: ResponsePromptParam | None = None, + ) -> Response: ... + + async def _fetch_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + previous_response_id: str | None = None, + conversation_id: str | None = None, + stream: Literal[True] | Literal[False] = False, + prompt: ResponsePromptParam | None = None, + ) -> Response | AsyncStream[ResponseStreamEvent]: + list_input = ItemHelpers.input_to_new_input_list(input) + list_input = _to_dump_compatible(list_input) + + if model_settings.parallel_tool_calls and tools: + parallel_tool_calls: bool | Omit = True + elif model_settings.parallel_tool_calls is False: + parallel_tool_calls = False + else: + parallel_tool_calls = omit + + tool_choice = Converter.convert_tool_choice(model_settings.tool_choice) + converted_tools = Converter.convert_tools(tools, handoffs) + converted_tools_payload = _to_dump_compatible(converted_tools.tools) + response_format = Converter.get_response_format(output_schema) + + include_set: set[str] = set(converted_tools.includes) + if model_settings.response_include is not None: + include_set.update(model_settings.response_include) + if model_settings.top_logprobs is not None: + include_set.add("message.output_text.logprobs") + include = cast(list[ResponseIncludable], list(include_set)) + + if _debug.DONT_LOG_MODEL_DATA: + logger.debug("Calling LLM") + else: + input_json = json.dumps( + list_input, + indent=2, + ensure_ascii=False, + ) + tools_json = json.dumps( + converted_tools_payload, + indent=2, + ensure_ascii=False, + ) + logger.debug( + f"Calling LLM {self.model} with input:\n" + f"{input_json}\n" + f"Tools:\n{tools_json}\n" + f"Stream: {stream}\n" + f"Tool choice: {tool_choice}\n" + f"Response format: {response_format}\n" + f"Previous response id: {previous_response_id}\n" + f"Conversation id: {conversation_id}\n" + ) + + extra_args = dict(model_settings.extra_args or {}) + if model_settings.top_logprobs is not None: + extra_args["top_logprobs"] = model_settings.top_logprobs + if model_settings.verbosity is not None: + if response_format is not omit: + response_format["verbosity"] = model_settings.verbosity # type: ignore [index] + else: + response_format = {"verbosity": model_settings.verbosity} + + stream_param: Literal[True] | Omit = True if stream else omit + + response = await self._client.responses.create( + previous_response_id=self._non_null_or_omit(previous_response_id), + conversation=self._non_null_or_omit(conversation_id), + instructions=self._non_null_or_omit(system_instructions), + model=self.model, + input=list_input, + include=include, + tools=converted_tools_payload, + prompt=self._non_null_or_omit(prompt), + temperature=self._non_null_or_omit(model_settings.temperature), + top_p=self._non_null_or_omit(model_settings.top_p), + truncation=self._non_null_or_omit(model_settings.truncation), + max_output_tokens=self._non_null_or_omit(model_settings.max_tokens), + tool_choice=tool_choice, + parallel_tool_calls=parallel_tool_calls, + stream=cast(Any, stream_param), + extra_headers=self._merge_headers(model_settings), + extra_query=model_settings.extra_query, + extra_body=model_settings.extra_body, + text=response_format, + store=self._non_null_or_omit(model_settings.store), + reasoning=self._non_null_or_omit(model_settings.reasoning), + metadata=self._non_null_or_omit(model_settings.metadata), + **extra_args, + ) + return cast(Union[Response, AsyncStream[ResponseStreamEvent]], response) + + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = AsyncOpenAI() + return self._client + + def _merge_headers(self, model_settings: ModelSettings): + return { + **_HEADERS, + **(model_settings.extra_headers or {}), + **(_HEADERS_OVERRIDE.get() or {}), + } + + +@dataclass +class ConvertedTools: + tools: list[ToolParam] + includes: list[ResponseIncludable] + + +class Converter: + @classmethod + def convert_tool_choice( + cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None + ) -> response_create_params.ToolChoice | Omit: + if tool_choice is None: + return omit + elif isinstance(tool_choice, MCPToolChoice): + return { + "server_label": tool_choice.server_label, + "type": "mcp", + "name": tool_choice.name, + } + elif tool_choice == "required": + return "required" + elif tool_choice == "auto": + return "auto" + elif tool_choice == "none": + return "none" + elif tool_choice == "file_search": + return { + "type": "file_search", + } + elif tool_choice == "web_search": + return { + # TODO: revist the type: ignore comment when ToolChoice is updated in the future + "type": "web_search", # type: ignore [typeddict-item] + } + elif tool_choice == "web_search_preview": + return { + "type": "web_search_preview", + } + elif tool_choice == "computer_use_preview": + return { + "type": "computer_use_preview", + } + elif tool_choice == "image_generation": + return { + "type": "image_generation", + } + elif tool_choice == "code_interpreter": + return { + "type": "code_interpreter", + } + elif tool_choice == "mcp": + # Note that this is still here for backwards compatibility, + # but migrating to MCPToolChoice is recommended. + return {"type": "mcp"} # type: ignore [typeddict-item] + else: + return { + "type": "function", + "name": tool_choice, + } + + @classmethod + def get_response_format( + cls, output_schema: AgentOutputSchemaBase | None + ) -> ResponseTextConfigParam | Omit: + if output_schema is None or output_schema.is_plain_text(): + return omit + else: + return { + "format": { + "type": "json_schema", + "name": "final_output", + "schema": output_schema.json_schema(), + "strict": output_schema.is_strict_json_schema(), + } + } + + @classmethod + def convert_tools( + cls, + tools: list[Tool], + handoffs: list[Handoff[Any, Any]], + ) -> ConvertedTools: + converted_tools: list[ToolParam] = [] + includes: list[ResponseIncludable] = [] + + computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] + if len(computer_tools) > 1: + raise UserError(f"You can only provide one computer tool. Got {len(computer_tools)}") + + for tool in tools: + converted_tool, include = cls._convert_tool(tool) + converted_tools.append(converted_tool) + if include: + includes.append(include) + + for handoff in handoffs: + converted_tools.append(cls._convert_handoff_tool(handoff)) + + return ConvertedTools(tools=converted_tools, includes=includes) + + @classmethod + def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, ResponseIncludable | None]: + """Returns converted tool and includes""" + + if isinstance(tool, FunctionTool): + converted_tool: ToolParam = { + "name": tool.name, + "parameters": tool.params_json_schema, + "strict": tool.strict_json_schema, + "type": "function", + "description": tool.description, + } + includes: ResponseIncludable | None = None + elif isinstance(tool, WebSearchTool): + # TODO: revist the type: ignore comment when ToolParam is updated in the future + converted_tool = { + "type": "web_search", + "filters": tool.filters.model_dump() if tool.filters is not None else None, # type: ignore [typeddict-item] + "user_location": tool.user_location, + "search_context_size": tool.search_context_size, + } + includes = None + elif isinstance(tool, FileSearchTool): + converted_tool = { + "type": "file_search", + "vector_store_ids": tool.vector_store_ids, + } + if tool.max_num_results: + converted_tool["max_num_results"] = tool.max_num_results + if tool.ranking_options: + converted_tool["ranking_options"] = tool.ranking_options + if tool.filters: + converted_tool["filters"] = tool.filters + + includes = "file_search_call.results" if tool.include_search_results else None + elif isinstance(tool, ComputerTool): + converted_tool = { + "type": "computer_use_preview", + "environment": tool.computer.environment, + "display_width": tool.computer.dimensions[0], + "display_height": tool.computer.dimensions[1], + } + includes = None + elif isinstance(tool, HostedMCPTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, ImageGenerationTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, CodeInterpreterTool): + converted_tool = tool.tool_config + includes = None + elif isinstance(tool, LocalShellTool): + converted_tool = { + "type": "local_shell", + } + includes = None + else: + raise UserError(f"Unknown tool type: {type(tool)}, tool") + + return converted_tool, includes + + @classmethod + def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam: + return { + "name": handoff.tool_name, + "parameters": handoff.input_json_schema, + "strict": handoff.strict_json_schema, + "type": "function", + "description": handoff.tool_description, + } diff --git a/src/agents/run.py b/src/agents/run.py index 55a0b59e3..5b25df4f2 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -1,1895 +1,1895 @@ -from __future__ import annotations - -import asyncio -import contextlib -import inspect -import os -import warnings -from dataclasses import dataclass, field -from typing import Any, Callable, Generic, cast, get_args - -from openai.types.responses import ( - ResponseCompletedEvent, - ResponseOutputItemDoneEvent, -) -from openai.types.responses.response_prompt_param import ( - ResponsePromptParam, -) -from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from typing_extensions import NotRequired, TypedDict, Unpack - -from ._run_impl import ( - AgentToolUseTracker, - NextStepFinalOutput, - NextStepHandoff, - NextStepRunAgain, - QueueCompleteSentinel, - RunImpl, - SingleStepResult, - TraceCtxManager, - get_model_tracing_impl, -) -from .agent import Agent -from .agent_output import AgentOutputSchema, AgentOutputSchemaBase -from .exceptions import ( - AgentsException, - InputGuardrailTripwireTriggered, - MaxTurnsExceeded, - ModelBehaviorError, - OutputGuardrailTripwireTriggered, - RunErrorDetails, - UserError, -) -from .guardrail import ( - InputGuardrail, - InputGuardrailResult, - OutputGuardrail, - OutputGuardrailResult, -) -from .handoffs import Handoff, HandoffInputFilter, handoff -from .items import ( - HandoffCallItem, - ItemHelpers, - ModelResponse, - ReasoningItem, - RunItem, - ToolCallItem, - ToolCallItemTypes, - TResponseInputItem, -) -from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase -from .logger import logger -from .memory import Session, SessionInputCallback -from .model_settings import ModelSettings -from .models.interface import Model, ModelProvider -from .models.multi_provider import MultiProvider -from .result import RunResult, RunResultStreaming -from .run_context import RunContextWrapper, TContext -from .stream_events import ( - AgentUpdatedStreamEvent, - RawResponsesStreamEvent, - RunItemStreamEvent, - StreamEvent, -) -from .tool import Tool -from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult -from .tracing import Span, SpanError, agent_span, get_current_trace, trace -from .tracing.span_data import AgentSpanData -from .usage import Usage -from .util import _coro, _error_tracing -from .util._types import MaybeAwaitable - -DEFAULT_MAX_TURNS = 10 - -DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore -# the value is set at the end of the module - - -def set_default_agent_runner(runner: AgentRunner | None) -> None: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly. - """ - global DEFAULT_AGENT_RUNNER - DEFAULT_AGENT_RUNNER = runner or AgentRunner() - - -def get_default_agent_runner() -> AgentRunner: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly. - """ - global DEFAULT_AGENT_RUNNER - return DEFAULT_AGENT_RUNNER - - -def _default_trace_include_sensitive_data() -> bool: - """Returns the default value for trace_include_sensitive_data based on environment variable.""" - val = os.getenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") - return val.strip().lower() in ("1", "true", "yes", "on") - - -@dataclass -class ModelInputData: - """Container for the data that will be sent to the model.""" - - input: list[TResponseInputItem] - instructions: str | None - - -@dataclass -class CallModelData(Generic[TContext]): - """Data passed to `RunConfig.call_model_input_filter` prior to model call.""" - - model_data: ModelInputData - agent: Agent[TContext] - context: TContext | None - - -@dataclass -class _ServerConversationTracker: - """Tracks server-side conversation state for either conversation_id or - previous_response_id modes.""" - - conversation_id: str | None = None - previous_response_id: str | None = None - sent_items: set[int] = field(default_factory=set) - server_items: set[int] = field(default_factory=set) - - def track_server_items(self, model_response: ModelResponse) -> None: - for output_item in model_response.output: - self.server_items.add(id(output_item)) - - # Update previous_response_id only when using previous_response_id - if ( - self.conversation_id is None - and self.previous_response_id is not None - and model_response.response_id is not None - ): - self.previous_response_id = model_response.response_id - - def prepare_input( - self, - original_input: str | list[TResponseInputItem], - generated_items: list[RunItem], - ) -> list[TResponseInputItem]: - input_items: list[TResponseInputItem] = [] - - # On first call (when there are no generated items yet), include the original input - if not generated_items: - input_items.extend(ItemHelpers.input_to_new_input_list(original_input)) - - # Process generated_items, skip items already sent or from server - for item in generated_items: - raw_item_id = id(item.raw_item) - - if raw_item_id in self.sent_items or raw_item_id in self.server_items: - continue - input_items.append(item.to_input_item()) - self.sent_items.add(raw_item_id) - - return input_items - - -# Type alias for the optional input filter callback -CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] - - -@dataclass -class RunConfig: - """Configures settings for the entire agent run.""" - - model: str | Model | None = None - """The model to use for the entire agent run. If set, will override the model set on every - agent. The model_provider passed in below must be able to resolve this model name. - """ - - model_provider: ModelProvider = field(default_factory=MultiProvider) - """The model provider to use when looking up string model names. Defaults to OpenAI.""" - - model_settings: ModelSettings | None = None - """Configure global model settings. Any non-null values will override the agent-specific model - settings. - """ - - handoff_input_filter: HandoffInputFilter | None = None - """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that - will take precedence. The input filter allows you to edit the inputs that are sent to the new - agent. See the documentation in `Handoff.input_filter` for more details. - """ - - input_guardrails: list[InputGuardrail[Any]] | None = None - """A list of input guardrails to run on the initial run input.""" - - output_guardrails: list[OutputGuardrail[Any]] | None = None - """A list of output guardrails to run on the final output of the run.""" - - tracing_disabled: bool = False - """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. - """ - - trace_include_sensitive_data: bool = field( - default_factory=_default_trace_include_sensitive_data - ) - """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or - LLM generations) in traces. If False, we'll still create spans for these events, but the - sensitive data will not be included. - """ - - workflow_name: str = "Agent workflow" - """The name of the run, used for tracing. Should be a logical name for the run, like - "Code generation workflow" or "Customer support agent". - """ - - trace_id: str | None = None - """A custom trace ID to use for tracing. If not provided, we will generate a new trace ID.""" - - group_id: str | None = None - """ - A grouping identifier to use for tracing, to link multiple traces from the same conversation - or process. For example, you might use a chat thread ID. - """ - - trace_metadata: dict[str, Any] | None = None - """ - An optional dictionary of additional metadata to include with the trace. - """ - - session_input_callback: SessionInputCallback | None = None - """Defines how to handle session history when new input is provided. - - `None` (default): The new input is appended to the session history. - - `SessionInputCallback`: A custom function that receives the history and new input, and - returns the desired combined list of items. - """ - - call_model_input_filter: CallModelInputFilter | None = None - """ - Optional callback that is invoked immediately before calling the model. It receives the current - agent, context and the model input (instructions and input items), and must return a possibly - modified `ModelInputData` to use for the model call. - - This allows you to edit the input sent to the model e.g. to stay within a token limit. - For example, you can use this to add a system prompt to the input. - """ - - -class RunOptions(TypedDict, Generic[TContext]): - """Arguments for ``AgentRunner`` methods.""" - - context: NotRequired[TContext | None] - """The context for the run.""" - - max_turns: NotRequired[int] - """The maximum number of turns to run for.""" - - hooks: NotRequired[RunHooks[TContext] | None] - """Lifecycle hooks for the run.""" - - run_config: NotRequired[RunConfig | None] - """Run configuration.""" - - previous_response_id: NotRequired[str | None] - """The ID of the previous response, if any.""" - - conversation_id: NotRequired[str | None] - """The ID of the stored conversation, if any.""" - - session: NotRequired[Session | None] - """The session for the run.""" - - -class Runner: - @classmethod - async def run( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResult: - """ - Run a workflow starting at the given agent. - - The agent will run in a loop until a final output is generated. The loop runs like so: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response. If using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The conversation ID - (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). - If provided, the conversation will be used to read and write items. - Every agent will have access to the conversation history so far, - and its output items will be written to the conversation. - We recommend only using this if you are exclusively using OpenAI models; - other model providers don't write to the Conversation object, - so you'll end up having partial conversations stored. - session: A session for automatic conversation history management. - - Returns: - A run result containing all the inputs, guardrail results and the output of - the last agent. Agents may perform handoffs, so we don't know the specific - type of the output. - """ - - runner = DEFAULT_AGENT_RUNNER - return await runner.run( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - @classmethod - def run_sync( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - *, - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResult: - """ - Run a workflow synchronously, starting at the given agent. - - Note: - This just wraps the `run` method, so it will not work if there's already an - event loop (e.g. inside an async function, or in a Jupyter notebook or async - context like FastAPI). For those cases, use the `run` method instead. - - The agent will run in a loop until a final output is generated. The loop runs: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The ID of the stored conversation, if any. - session: A session for automatic conversation history management. - - Returns: - A run result containing all the inputs, guardrail results and the output of - the last agent. Agents may perform handoffs, so we don't know the specific - type of the output. - """ - - runner = DEFAULT_AGENT_RUNNER - return runner.run_sync( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - @classmethod - def run_streamed( - cls, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, - hooks: RunHooks[TContext] | None = None, - run_config: RunConfig | None = None, - previous_response_id: str | None = None, - conversation_id: str | None = None, - session: Session | None = None, - ) -> RunResultStreaming: - """ - Run a workflow starting at the given agent in streaming mode. - - The returned result object contains a method you can use to stream semantic - events as they are generated. - - The agent will run in a loop until a final output is generated. The loop runs like so: - - 1. The agent is invoked with the given input. - 2. If there is a final output (i.e. the agent produces something of type - `agent.output_type`), the loop terminates. - 3. If there's a handoff, we run the loop again, with the new agent. - 4. Else, we run tool calls (if any), and re-run the loop. - - In two cases, the agent may raise an exception: - - 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. - 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered - exception is raised. - - Note: - Only the first agent's input guardrails are run. - - Args: - starting_agent: The starting agent to run. - input: The initial input to the agent. You can pass a single string for a - user message, or a list of input items. - context: The context to run the agent with. - max_turns: The maximum number of turns to run the agent for. A turn is - defined as one AI invocation (including any tool calls that might occur). - hooks: An object that receives callbacks on various lifecycle events. - run_config: Global settings for the entire agent run. - previous_response_id: The ID of the previous response, if using OpenAI - models via the Responses API, this allows you to skip passing in input - from the previous turn. - conversation_id: The ID of the stored conversation, if any. - session: A session for automatic conversation history management. - - Returns: - A result object that contains data about the run, as well as a method to - stream events. - """ - - runner = DEFAULT_AGENT_RUNNER - return runner.run_streamed( - starting_agent, - input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - - -class AgentRunner: - """ - WARNING: this class is experimental and not part of the public API - It should not be used directly or subclassed. - """ - - async def run( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResult: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - if run_config is None: - run_config = RunConfig() - - if conversation_id is not None or previous_response_id is not None: - server_conversation_tracker = _ServerConversationTracker( - conversation_id=conversation_id, previous_response_id=previous_response_id - ) - else: - server_conversation_tracker = None - - # Keep original user input separate from session-prepared input - original_user_input = input - prepared_input = await self._prepare_input_with_session( - input, session, run_config.session_input_callback - ) - - tool_use_tracker = AgentToolUseTracker() - - with TraceCtxManager( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ): - current_turn = 0 - original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) - generated_items: list[RunItem] = [] - model_responses: list[ModelResponse] = [] - - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context, # type: ignore - ) - - input_guardrail_results: list[InputGuardrailResult] = [] - tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] - tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] - - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - should_run_agent_start_hooks = True - - # save only the new user input to the session, not the combined history - await self._save_result_to_session(session, original_user_input, []) - - try: - while True: - all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [ - h.agent_name - for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := AgentRunner._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) - current_span.span_data.tools = [t.name for t in all_tools] - - current_turn += 1 - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") - - logger.debug( - f"Running agent {current_agent.name} (turn {current_turn})", - ) - - if current_turn == 1: - input_guardrail_results, turn_result = await asyncio.gather( - self._run_input_guardrails( - starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), - _copy_str_or_list(prepared_input), - context_wrapper, - ), - self._run_single_turn( - agent=current_agent, - all_tools=all_tools, - original_input=original_input, - generated_items=generated_items, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - should_run_agent_start_hooks=should_run_agent_start_hooks, - tool_use_tracker=tool_use_tracker, - server_conversation_tracker=server_conversation_tracker, - ), - ) - else: - turn_result = await self._run_single_turn( - agent=current_agent, - all_tools=all_tools, - original_input=original_input, - generated_items=generated_items, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - should_run_agent_start_hooks=should_run_agent_start_hooks, - tool_use_tracker=tool_use_tracker, - server_conversation_tracker=server_conversation_tracker, - ) - should_run_agent_start_hooks = False - - model_responses.append(turn_result.model_response) - original_input = turn_result.original_input - generated_items = turn_result.generated_items - - if server_conversation_tracker is not None: - server_conversation_tracker.track_server_items(turn_result.model_response) - - # Collect tool guardrail results from this turn - tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) - tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) - - if isinstance(turn_result.next_step, NextStepFinalOutput): - output_guardrail_results = await self._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - result = RunResult( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - final_output=turn_result.next_step.output, - _last_agent=current_agent, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=output_guardrail_results, - tool_input_guardrail_results=tool_input_guardrail_results, - tool_output_guardrail_results=tool_output_guardrail_results, - context_wrapper=context_wrapper, - ) - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items - ) - - return result - elif isinstance(turn_result.next_step, NextStepHandoff): - current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - elif isinstance(turn_result.next_step, NextStepRunAgain): - if not any( - guardrail_result.output.tripwire_triggered - for guardrail_result in input_guardrail_results - ): - await self._save_result_to_session( - session, [], turn_result.new_step_items - ) - else: - raise AgentsException( - f"Unknown next step type: {type(turn_result.next_step)}" - ) - except AgentsException as exc: - exc.run_data = RunErrorDetails( - input=original_input, - new_items=generated_items, - raw_responses=model_responses, - last_agent=current_agent, - context_wrapper=context_wrapper, - input_guardrail_results=input_guardrail_results, - output_guardrail_results=[], - ) - raise - finally: - if current_span: - current_span.finish(reset_current=True) - - def run_sync( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResult: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - - # Python 3.14 stopped implicitly wiring up a default event loop - # when synchronous code touches asyncio APIs for the first time. - # Several of our synchronous entry points (for example the Redis/SQLAlchemy session helpers) - # construct asyncio primitives like asyncio.Lock during __init__, - # which binds them to whatever loop happens to be the thread's default at that moment. - # To keep those locks usable we must ensure that run_sync reuses that same default loop - # instead of hopping over to a brand-new asyncio.run() loop. - try: - already_running_loop = asyncio.get_running_loop() - except RuntimeError: - already_running_loop = None - - if already_running_loop is not None: - # This method is only expected to run when no loop is already active. - # (Each thread has its own default loop; concurrent sync runs should happen on - # different threads. In a single thread use the async API to interleave work.) - raise RuntimeError( - "AgentRunner.run_sync() cannot be called when an event loop is already running." - ) - - policy = asyncio.get_event_loop_policy() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - try: - default_loop = policy.get_event_loop() - except RuntimeError: - default_loop = policy.new_event_loop() - policy.set_event_loop(default_loop) - - # We intentionally leave the default loop open even if we had to create one above. Session - # instances and other helpers stash loop-bound primitives between calls and expect to find - # the same default loop every time run_sync is invoked on this thread. - # Schedule the async run on the default loop so that we can manage cancellation explicitly. - task = default_loop.create_task( - self.run( - starting_agent, - input, - session=session, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - ) - ) - - try: - # Drive the coroutine to completion, harvesting the final RunResult. - return default_loop.run_until_complete(task) - except BaseException: - # If the sync caller aborts (KeyboardInterrupt, etc.), make sure the scheduled task - # does not linger on the shared loop by cancelling it and waiting for completion. - if not task.done(): - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - default_loop.run_until_complete(task) - raise - finally: - if not default_loop.is_closed(): - # The loop stays open for subsequent runs, but we still need to flush any pending - # async generators so their cleanup code executes promptly. - with contextlib.suppress(RuntimeError): - default_loop.run_until_complete(default_loop.shutdown_asyncgens()) - - def run_streamed( - self, - starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], - **kwargs: Unpack[RunOptions[TContext]], - ) -> RunResultStreaming: - context = kwargs.get("context") - max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) - run_config = kwargs.get("run_config") - previous_response_id = kwargs.get("previous_response_id") - conversation_id = kwargs.get("conversation_id") - session = kwargs.get("session") - - if run_config is None: - run_config = RunConfig() - - # If there's already a trace, we don't create a new one. In addition, we can't end the - # trace here, because the actual work is done in `stream_events` and this method ends - # before that. - new_trace = ( - None - if get_current_trace() - else trace( - workflow_name=run_config.workflow_name, - trace_id=run_config.trace_id, - group_id=run_config.group_id, - metadata=run_config.trace_metadata, - disabled=run_config.tracing_disabled, - ) - ) - - output_schema = AgentRunner._get_output_schema(starting_agent) - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context # type: ignore - ) - - streamed_result = RunResultStreaming( - input=_copy_str_or_list(input), - new_items=[], - current_agent=starting_agent, - raw_responses=[], - final_output=None, - is_complete=False, - current_turn=0, - max_turns=max_turns, - input_guardrail_results=[], - output_guardrail_results=[], - tool_input_guardrail_results=[], - tool_output_guardrail_results=[], - _current_agent_output_schema=output_schema, - trace=new_trace, - context_wrapper=context_wrapper, - ) - - # Kick off the actual agent loop in the background and return the streamed result object. - streamed_result._run_impl_task = asyncio.create_task( - self._start_streaming( - starting_input=input, - streamed_result=streamed_result, - starting_agent=starting_agent, - max_turns=max_turns, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - session=session, - ) - ) - return streamed_result - - @staticmethod - def _validate_run_hooks( - hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, - ) -> RunHooks[Any]: - if hooks is None: - return RunHooks[Any]() - input_hook_type = type(hooks).__name__ - if isinstance(hooks, AgentHooksBase): - raise TypeError( - "Run hooks must be instances of RunHooks. " - f"Received agent-scoped hooks ({input_hook_type}). " - "Attach AgentHooks to an Agent via Agent(..., hooks=...)." - ) - if not isinstance(hooks, RunHooksBase): - raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") - return hooks - - @classmethod - async def _maybe_filter_model_input( - cls, - *, - agent: Agent[TContext], - run_config: RunConfig, - context_wrapper: RunContextWrapper[TContext], - input_items: list[TResponseInputItem], - system_instructions: str | None, - ) -> ModelInputData: - """Apply optional call_model_input_filter to modify model input. - - Returns a `ModelInputData` that will be sent to the model. - """ - effective_instructions = system_instructions - effective_input: list[TResponseInputItem] = input_items - - if run_config.call_model_input_filter is None: - return ModelInputData(input=effective_input, instructions=effective_instructions) - - try: - model_input = ModelInputData( - input=effective_input.copy(), - instructions=effective_instructions, - ) - filter_payload: CallModelData[TContext] = CallModelData( - model_data=model_input, - agent=agent, - context=context_wrapper.context, - ) - maybe_updated = run_config.call_model_input_filter(filter_payload) - updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated - if not isinstance(updated, ModelInputData): - raise UserError("call_model_input_filter must return a ModelInputData instance") - return updated - except Exception as e: - _error_tracing.attach_error_to_current_span( - SpanError(message="Error in call_model_input_filter", data={"error": str(e)}) - ) - raise - - @classmethod - async def _run_input_guardrails_with_queue( - cls, - agent: Agent[Any], - guardrails: list[InputGuardrail[TContext]], - input: str | list[TResponseInputItem], - context: RunContextWrapper[TContext], - streamed_result: RunResultStreaming, - parent_span: Span[Any], - ): - queue = streamed_result._input_guardrail_queue - - # We'll run the guardrails and push them onto the queue as they complete - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) - ) - for guardrail in guardrails - ] - guardrail_results = [] - try: - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - _error_tracing.attach_error_to_span( - parent_span, - SpanError( - message="Guardrail tripwire triggered", - data={ - "guardrail": result.guardrail.get_name(), - "type": "input_guardrail", - }, - ), - ) - queue.put_nowait(result) - guardrail_results.append(result) - except Exception: - for t in guardrail_tasks: - t.cancel() - raise - - streamed_result.input_guardrail_results = guardrail_results - - @classmethod - async def _start_streaming( - cls, - starting_input: str | list[TResponseInputItem], - streamed_result: RunResultStreaming, - starting_agent: Agent[TContext], - max_turns: int, - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - previous_response_id: str | None, - conversation_id: str | None, - session: Session | None, - ): - if streamed_result.trace: - streamed_result.trace.start(mark_as_current=True) - - current_span: Span[AgentSpanData] | None = None - current_agent = starting_agent - current_turn = 0 - should_run_agent_start_hooks = True - tool_use_tracker = AgentToolUseTracker() - - if conversation_id is not None or previous_response_id is not None: - server_conversation_tracker = _ServerConversationTracker( - conversation_id=conversation_id, previous_response_id=previous_response_id - ) - else: - server_conversation_tracker = None - - streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) - - try: - # Prepare input with session if enabled - prepared_input = await AgentRunner._prepare_input_with_session( - starting_input, session, run_config.session_input_callback - ) - - # Update the streamed result with the prepared input - streamed_result.input = prepared_input - - await AgentRunner._save_result_to_session(session, starting_input, []) - - while True: - # Check for soft cancel before starting new turn - if streamed_result._cancel_mode == "after_turn": - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if streamed_result.is_complete: - break - - all_tools = await cls._get_all_tools(current_agent, context_wrapper) - - # Start an agent span if we don't have one. This span is ended if the current - # agent changes, or if the agent loop ends. - if current_span is None: - handoff_names = [ - h.agent_name - for h in await cls._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, - ) - current_span.start(mark_as_current=True) - tool_names = [t.name for t in all_tools] - current_span.span_data.tools = tool_names - current_turn += 1 - streamed_result.current_turn = current_turn - - if current_turn > max_turns: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Max turns exceeded", - data={"max_turns": max_turns}, - ), - ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - - if current_turn == 1: - # Run the input guardrails in the background and put the results on the queue - streamed_result._input_guardrails_task = asyncio.create_task( - cls._run_input_guardrails_with_queue( - starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), - ItemHelpers.input_to_new_input_list(prepared_input), - context_wrapper, - streamed_result, - current_span, - ) - ) - try: - turn_result = await cls._run_single_turn_streamed( - streamed_result, - current_agent, - hooks, - context_wrapper, - run_config, - should_run_agent_start_hooks, - tool_use_tracker, - all_tools, - server_conversation_tracker, - ) - should_run_agent_start_hooks = False - - streamed_result.raw_responses = streamed_result.raw_responses + [ - turn_result.model_response - ] - streamed_result.input = turn_result.original_input - streamed_result.new_items = turn_result.generated_items - - if server_conversation_tracker is not None: - server_conversation_tracker.track_server_items(turn_result.model_response) - - if isinstance(turn_result.next_step, NextStepHandoff): - # Save the conversation to session if enabled (before handoff) - # Note: Non-streaming path doesn't save handoff turns immediately, - # but streaming needs to for graceful cancellation support - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) - current_span = None - should_run_agent_start_hooks = True - streamed_result._event_queue.put_nowait( - AgentUpdatedStreamEvent(new_agent=current_agent) - ) - - # Check for soft cancel after handoff - if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - elif isinstance(turn_result.next_step, NextStepFinalOutput): - streamed_result._output_guardrails_task = asyncio.create_task( - cls._run_output_guardrails( - current_agent.output_guardrails - + (run_config.output_guardrails or []), - current_agent, - turn_result.next_step.output, - context_wrapper, - ) - ) - - try: - output_guardrail_results = await streamed_result._output_guardrails_task - except Exception: - # Exceptions will be checked in the stream_events loop - output_guardrail_results = [] - - streamed_result.output_guardrail_results = output_guardrail_results - streamed_result.final_output = turn_result.next_step.output - streamed_result.is_complete = True - - # Save the conversation to session if enabled - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - elif isinstance(turn_result.next_step, NextStepRunAgain): - if session is not None: - should_skip_session_save = ( - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - ) - if should_skip_session_save is False: - await AgentRunner._save_result_to_session( - session, [], turn_result.new_step_items - ) - - # Check for soft cancel after turn completion - if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - break - except AgentsException as exc: - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - exc.run_data = RunErrorDetails( - input=streamed_result.input, - new_items=streamed_result.new_items, - raw_responses=streamed_result.raw_responses, - last_agent=current_agent, - context_wrapper=context_wrapper, - input_guardrail_results=streamed_result.input_guardrail_results, - output_guardrail_results=streamed_result.output_guardrail_results, - ) - raise - except Exception as e: - if current_span: - _error_tracing.attach_error_to_span( - current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), - ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) - raise - - streamed_result.is_complete = True - finally: - if streamed_result._input_guardrails_task: - try: - await AgentRunner._input_guardrail_tripwire_triggered_for_stream( - streamed_result - ) - except Exception as e: - logger.debug( - f"Error in streamed_result finalize for agent {current_agent.name} - {e}" - ) - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) - - @classmethod - async def _run_single_turn_streamed( - cls, - streamed_result: RunResultStreaming, - agent: Agent[TContext], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - should_run_agent_start_hooks: bool, - tool_use_tracker: AgentToolUseTracker, - all_tools: list[Tool], - server_conversation_tracker: _ServerConversationTracker | None = None, - ) -> SingleStepResult: - emitted_tool_call_ids: set[str] = set() - emitted_reasoning_item_ids: set[str] = set() - - if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - output_schema = cls._get_output_schema(agent) - - streamed_result.current_agent = agent - streamed_result._current_agent_output_schema = output_schema - - system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), - ) - - handoffs = await cls._get_handoffs(agent, context_wrapper) - model = cls._get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) - - final_response: ModelResponse | None = None - - if server_conversation_tracker is not None: - input = server_conversation_tracker.prepare_input( - streamed_result.input, streamed_result.new_items - ) - else: - input = ItemHelpers.input_to_new_input_list(streamed_result.input) - input.extend([item.to_input_item() for item in streamed_result.new_items]) - - # THIS IS THE RESOLVED CONFLICT BLOCK - filtered = await cls._maybe_filter_model_input( - agent=agent, - run_config=run_config, - context_wrapper=context_wrapper, - input_items=input, - system_instructions=system_prompt, - ) - - # Call hook just before the model is invoked, with the correct system_prompt. - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, agent, filtered.instructions, filtered.input - ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) - - # 1. Stream the output events - async for event in model.stream_response( - filtered.instructions, - filtered.input, - model_settings, - all_tools, - output_schema, - handoffs, - get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - ): - # Emit the raw event ASAP - streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) - - if isinstance(event, ResponseCompletedEvent): - usage = ( - Usage( - requests=1, - input_tokens=event.response.usage.input_tokens, - output_tokens=event.response.usage.output_tokens, - total_tokens=event.response.usage.total_tokens, - input_tokens_details=event.response.usage.input_tokens_details, - output_tokens_details=event.response.usage.output_tokens_details, - ) - if event.response.usage - else Usage() - ) - final_response = ModelResponse( - output=event.response.output, - usage=usage, - response_id=event.response.id, - ) - context_wrapper.usage.add(usage) - - if isinstance(event, ResponseOutputItemDoneEvent): - output_item = event.item - - if isinstance(output_item, _TOOL_CALL_TYPES): - call_id: str | None = getattr( - output_item, "call_id", getattr(output_item, "id", None) - ) - - if call_id and call_id not in emitted_tool_call_ids: - emitted_tool_call_ids.add(call_id) - - tool_item = ToolCallItem( - raw_item=cast(ToolCallItemTypes, output_item), - agent=agent, - ) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=tool_item, name="tool_called") - ) - - elif isinstance(output_item, ResponseReasoningItem): - reasoning_id: str | None = getattr(output_item, "id", None) - - if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: - emitted_reasoning_item_ids.add(reasoning_id) - - reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) - streamed_result._event_queue.put_nowait( - RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") - ) - - # Call hook just after the model response is finalized. - if final_response is not None: - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, final_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, final_response), - ) - - # 2. At this point, the streaming is complete for this turn of the agent loop. - if not final_response: - raise ModelBehaviorError("Model did not produce a final response!") - - # 3. Now, we can process the turn as we do in the non-streaming case - single_step_result = await cls._get_single_step_result_from_response( - agent=agent, - original_input=streamed_result.input, - pre_step_items=streamed_result.new_items, - new_response=final_response, - output_schema=output_schema, - all_tools=all_tools, - handoffs=handoffs, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - tool_use_tracker=tool_use_tracker, - event_queue=streamed_result._event_queue, - ) - - import dataclasses as _dc - - # Filter out items that have already been sent to avoid duplicates - items_to_filter = single_step_result.new_step_items - - if emitted_tool_call_ids: - # Filter out tool call items that were already emitted during streaming - items_to_filter = [ - item - for item in items_to_filter - if not ( - isinstance(item, ToolCallItem) - and ( - call_id := getattr( - item.raw_item, "call_id", getattr(item.raw_item, "id", None) - ) - ) - and call_id in emitted_tool_call_ids - ) - ] - - if emitted_reasoning_item_ids: - # Filter out reasoning items that were already emitted during streaming - items_to_filter = [ - item - for item in items_to_filter - if not ( - isinstance(item, ReasoningItem) - and (reasoning_id := getattr(item.raw_item, "id", None)) - and reasoning_id in emitted_reasoning_item_ids - ) - ] - - # Filter out HandoffCallItem to avoid duplicates (already sent earlier) - items_to_filter = [ - item for item in items_to_filter if not isinstance(item, HandoffCallItem) - ] - - # Create filtered result and send to queue - filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter) - RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) - return single_step_result - - @classmethod - async def _run_single_turn( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - original_input: str | list[TResponseInputItem], - generated_items: list[RunItem], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - should_run_agent_start_hooks: bool, - tool_use_tracker: AgentToolUseTracker, - server_conversation_tracker: _ServerConversationTracker | None = None, - ) -> SingleStepResult: - # Ensure we run the hooks before anything else - if should_run_agent_start_hooks: - await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), - ( - agent.hooks.on_start(context_wrapper, agent) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - system_prompt, prompt_config = await asyncio.gather( - agent.get_system_prompt(context_wrapper), - agent.get_prompt(context_wrapper), - ) - - output_schema = cls._get_output_schema(agent) - handoffs = await cls._get_handoffs(agent, context_wrapper) - if server_conversation_tracker is not None: - input = server_conversation_tracker.prepare_input(original_input, generated_items) - else: - input = ItemHelpers.input_to_new_input_list(original_input) - input.extend([generated_item.to_input_item() for generated_item in generated_items]) - - new_response = await cls._get_new_response( - agent, - system_prompt, - input, - output_schema, - all_tools, - handoffs, - hooks, - context_wrapper, - run_config, - tool_use_tracker, - server_conversation_tracker, - prompt_config, - ) - - return await cls._get_single_step_result_from_response( - agent=agent, - original_input=original_input, - pre_step_items=generated_items, - new_response=new_response, - output_schema=output_schema, - all_tools=all_tools, - handoffs=handoffs, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - tool_use_tracker=tool_use_tracker, - ) - - @classmethod - async def _get_single_step_result_from_response( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - original_input: str | list[TResponseInputItem], - pre_step_items: list[RunItem], - new_response: ModelResponse, - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, - ) -> SingleStepResult: - processed_response = RunImpl.process_model_response( - agent=agent, - all_tools=all_tools, - response=new_response, - output_schema=output_schema, - handoffs=handoffs, - ) - - tool_use_tracker.add_tool_use(agent, processed_response.tools_used) - - # Send handoff items immediately for streaming, but avoid duplicates - if event_queue is not None and processed_response.new_items: - handoff_items = [ - item for item in processed_response.new_items if isinstance(item, HandoffCallItem) - ] - if handoff_items: - RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) - - return await RunImpl.execute_tools_and_side_effects( - agent=agent, - original_input=original_input, - pre_step_items=pre_step_items, - new_response=new_response, - processed_response=processed_response, - output_schema=output_schema, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - ) - - @classmethod - async def _get_single_step_result_from_streamed_response( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - streamed_result: RunResultStreaming, - new_response: ModelResponse, - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - ) -> SingleStepResult: - original_input = streamed_result.input - pre_step_items = streamed_result.new_items - event_queue = streamed_result._event_queue - - processed_response = RunImpl.process_model_response( - agent=agent, - all_tools=all_tools, - response=new_response, - output_schema=output_schema, - handoffs=handoffs, - ) - new_items_processed_response = processed_response.new_items - tool_use_tracker.add_tool_use(agent, processed_response.tools_used) - RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue) - - single_step_result = await RunImpl.execute_tools_and_side_effects( - agent=agent, - original_input=original_input, - pre_step_items=pre_step_items, - new_response=new_response, - processed_response=processed_response, - output_schema=output_schema, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - ) - new_step_items = [ - item - for item in single_step_result.new_step_items - if item not in new_items_processed_response - ] - RunImpl.stream_step_items_to_queue(new_step_items, event_queue) - - return single_step_result - - @classmethod - async def _run_input_guardrails( - cls, - agent: Agent[Any], - guardrails: list[InputGuardrail[TContext]], - input: str | list[TResponseInputItem], - context: RunContextWrapper[TContext], - ) -> list[InputGuardrailResult]: - if not guardrails: - return [] - - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_input_guardrail(agent, guardrail, input, context) - ) - for guardrail in guardrails - ] - - guardrail_results = [] - - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - # Cancel all guardrail tasks if a tripwire is triggered. - for t in guardrail_tasks: - t.cancel() - _error_tracing.attach_error_to_current_span( - SpanError( - message="Guardrail tripwire triggered", - data={"guardrail": result.guardrail.get_name()}, - ) - ) - raise InputGuardrailTripwireTriggered(result) - else: - guardrail_results.append(result) - - return guardrail_results - - @classmethod - async def _run_output_guardrails( - cls, - guardrails: list[OutputGuardrail[TContext]], - agent: Agent[TContext], - agent_output: Any, - context: RunContextWrapper[TContext], - ) -> list[OutputGuardrailResult]: - if not guardrails: - return [] - - guardrail_tasks = [ - asyncio.create_task( - RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) - ) - for guardrail in guardrails - ] - - guardrail_results = [] - - for done in asyncio.as_completed(guardrail_tasks): - result = await done - if result.output.tripwire_triggered: - # Cancel all guardrail tasks if a tripwire is triggered. - for t in guardrail_tasks: - t.cancel() - _error_tracing.attach_error_to_current_span( - SpanError( - message="Guardrail tripwire triggered", - data={"guardrail": result.guardrail.get_name()}, - ) - ) - raise OutputGuardrailTripwireTriggered(result) - else: - guardrail_results.append(result) - - return guardrail_results - - @classmethod - async def _get_new_response( - cls, - agent: Agent[TContext], - system_prompt: str | None, - input: list[TResponseInputItem], - output_schema: AgentOutputSchemaBase | None, - all_tools: list[Tool], - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - server_conversation_tracker: _ServerConversationTracker | None, - prompt_config: ResponsePromptParam | None, - ) -> ModelResponse: - # Allow user to modify model input right before the call, if configured - filtered = await cls._maybe_filter_model_input( - agent=agent, - run_config=run_config, - context_wrapper=context_wrapper, - input_items=input, - system_instructions=system_prompt, - ) - - model = cls._get_model(agent, run_config) - model_settings = agent.model_settings.resolve(run_config.model_settings) - model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) - - # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call - await asyncio.gather( - hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), - ( - agent.hooks.on_llm_start( - context_wrapper, - agent, - filtered.instructions, # Use filtered instructions - filtered.input, # Use filtered input - ) - if agent.hooks - else _coro.noop_coroutine() - ), - ) - - previous_response_id = ( - server_conversation_tracker.previous_response_id - if server_conversation_tracker - else None - ) - conversation_id = ( - server_conversation_tracker.conversation_id if server_conversation_tracker else None - ) - - new_response = await model.get_response( - system_instructions=filtered.instructions, - input=filtered.input, - model_settings=model_settings, - tools=all_tools, - output_schema=output_schema, - handoffs=handoffs, - tracing=get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - ) - - context_wrapper.usage.add(new_response.usage) - - # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call - await asyncio.gather( - ( - agent.hooks.on_llm_end(context_wrapper, agent, new_response) - if agent.hooks - else _coro.noop_coroutine() - ), - hooks.on_llm_end(context_wrapper, agent, new_response), - ) - - return new_response - - @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: - if agent.output_type is None or agent.output_type is str: - return None - elif isinstance(agent.output_type, AgentOutputSchemaBase): - return agent.output_type - - return AgentOutputSchema(agent.output_type) - - @classmethod - async def _get_handoffs( - cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] - ) -> list[Handoff]: - handoffs = [] - for handoff_item in agent.handoffs: - if isinstance(handoff_item, Handoff): - handoffs.append(handoff_item) - elif isinstance(handoff_item, Agent): - handoffs.append(handoff(handoff_item)) - - async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: - attr = handoff_obj.is_enabled - if isinstance(attr, bool): - return attr - res = attr(context_wrapper, agent) - if inspect.isawaitable(res): - return bool(await res) - return bool(res) - - results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) - enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] - return enabled - - @classmethod - async def _get_all_tools( - cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] - ) -> list[Tool]: - return await agent.get_all_tools(context_wrapper) - - @classmethod - def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: - if isinstance(run_config.model, Model): - return run_config.model - elif isinstance(run_config.model, str): - return run_config.model_provider.get_model(run_config.model) - elif isinstance(agent.model, Model): - return agent.model - - return run_config.model_provider.get_model(agent.model) - - @classmethod - async def _prepare_input_with_session( - cls, - input: str | list[TResponseInputItem], - session: Session | None, - session_input_callback: SessionInputCallback | None, - ) -> str | list[TResponseInputItem]: - """Prepare input by combining it with session history if enabled.""" - if session is None: - return input - - # If the user doesn't specify an input callback and pass a list as input - if isinstance(input, list) and not session_input_callback: - raise UserError( - "When using session memory, list inputs require a " - "`RunConfig.session_input_callback` to define how they should be merged " - "with the conversation history. If you don't want to use a callback, " - "provide your input as a string instead, or disable session memory " - "(session=None) and pass a list to manage the history manually." - ) - - # Get previous conversation history - history = await session.get_items() - - # Convert input to list format - new_input_list = ItemHelpers.input_to_new_input_list(input) - - if session_input_callback is None: - return history + new_input_list - elif callable(session_input_callback): - res = session_input_callback(history, new_input_list) - if inspect.isawaitable(res): - return await res - return res - else: - raise UserError( - f"Invalid `session_input_callback` value: {session_input_callback}. " - "Choose between `None` or a custom callable function." - ) - - @classmethod - async def _save_result_to_session( - cls, - session: Session | None, - original_input: str | list[TResponseInputItem], - new_items: list[RunItem], - ) -> None: - """ - Save the conversation turn to session. - It does not account for any filtering or modification performed by - `RunConfig.session_input_callback`. - """ - if session is None: - return - - # Convert original input to list format if needed - input_list = ItemHelpers.input_to_new_input_list(original_input) - - # Convert new items to input format - new_items_as_input = [item.to_input_item() for item in new_items] - - # Save all items from this turn - items_to_save = input_list + new_items_as_input - await session.add_items(items_to_save) - - @staticmethod - async def _input_guardrail_tripwire_triggered_for_stream( - streamed_result: RunResultStreaming, - ) -> bool: - """Return True if any input guardrail triggered during a streamed run.""" - - task = streamed_result._input_guardrails_task - if task is None: - return False - - if not task.done(): - await task - - return any( - guardrail_result.output.tripwire_triggered - for guardrail_result in streamed_result.input_guardrail_results - ) - - -DEFAULT_AGENT_RUNNER = AgentRunner() -_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) - - -def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: - if isinstance(input, str): - return input - return input.copy() +from __future__ import annotations + +import asyncio +import contextlib +import inspect +import os +import warnings +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, cast, get_args + +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseOutputItemDoneEvent, +) +from openai.types.responses.response_prompt_param import ( + ResponsePromptParam, +) +from openai.types.responses.response_reasoning_item import ResponseReasoningItem +from typing_extensions import NotRequired, TypedDict, Unpack + +from ._run_impl import ( + AgentToolUseTracker, + NextStepFinalOutput, + NextStepHandoff, + NextStepRunAgain, + QueueCompleteSentinel, + RunImpl, + SingleStepResult, + TraceCtxManager, + get_model_tracing_impl, +) +from .agent import Agent +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase +from .exceptions import ( + AgentsException, + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, + ModelBehaviorError, + OutputGuardrailTripwireTriggered, + RunErrorDetails, + UserError, +) +from .guardrail import ( + InputGuardrail, + InputGuardrailResult, + OutputGuardrail, + OutputGuardrailResult, +) +from .handoffs import Handoff, HandoffInputFilter, handoff +from .items import ( + HandoffCallItem, + ItemHelpers, + ModelResponse, + ReasoningItem, + RunItem, + ToolCallItem, + ToolCallItemTypes, + TResponseInputItem, +) +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase +from .logger import logger +from .memory import Session, SessionInputCallback +from .model_settings import ModelSettings +from .models.interface import Model, ModelProvider +from .models.multi_provider import MultiProvider +from .result import RunResult, RunResultStreaming +from .run_context import RunContextWrapper, TContext +from .stream_events import ( + AgentUpdatedStreamEvent, + RawResponsesStreamEvent, + RunItemStreamEvent, + StreamEvent, +) +from .tool import Tool +from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult +from .tracing import Span, SpanError, agent_span, get_current_trace, trace +from .tracing.span_data import AgentSpanData +from .usage import Usage +from .util import _coro, _error_tracing +from .util._types import MaybeAwaitable + +DEFAULT_MAX_TURNS = 10 + +DEFAULT_AGENT_RUNNER: AgentRunner = None # type: ignore +# the value is set at the end of the module + + +def set_default_agent_runner(runner: AgentRunner | None) -> None: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly. + """ + global DEFAULT_AGENT_RUNNER + DEFAULT_AGENT_RUNNER = runner or AgentRunner() + + +def get_default_agent_runner() -> AgentRunner: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly. + """ + global DEFAULT_AGENT_RUNNER + return DEFAULT_AGENT_RUNNER + + +def _default_trace_include_sensitive_data() -> bool: + """Returns the default value for trace_include_sensitive_data based on environment variable.""" + val = os.getenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") + return val.strip().lower() in ("1", "true", "yes", "on") + + +@dataclass +class ModelInputData: + """Container for the data that will be sent to the model.""" + + input: list[TResponseInputItem] + instructions: str | None + + +@dataclass +class CallModelData(Generic[TContext]): + """Data passed to `RunConfig.call_model_input_filter` prior to model call.""" + + model_data: ModelInputData + agent: Agent[TContext] + context: TContext | None + + +@dataclass +class _ServerConversationTracker: + """Tracks server-side conversation state for either conversation_id or + previous_response_id modes.""" + + conversation_id: str | None = None + previous_response_id: str | None = None + sent_items: set[int] = field(default_factory=set) + server_items: set[int] = field(default_factory=set) + + def track_server_items(self, model_response: ModelResponse) -> None: + for output_item in model_response.output: + self.server_items.add(id(output_item)) + + # Update previous_response_id only when using previous_response_id + if ( + self.conversation_id is None + and self.previous_response_id is not None + and model_response.response_id is not None + ): + self.previous_response_id = model_response.response_id + + def prepare_input( + self, + original_input: str | list[TResponseInputItem], + generated_items: list[RunItem], + ) -> list[TResponseInputItem]: + input_items: list[TResponseInputItem] = [] + + # On first call (when there are no generated items yet), include the original input + if not generated_items: + input_items.extend(ItemHelpers.input_to_new_input_list(original_input)) + + # Process generated_items, skip items already sent or from server + for item in generated_items: + raw_item_id = id(item.raw_item) + + if raw_item_id in self.sent_items or raw_item_id in self.server_items: + continue + input_items.append(item.to_input_item()) + self.sent_items.add(raw_item_id) + + return input_items + + +# Type alias for the optional input filter callback +CallModelInputFilter = Callable[[CallModelData[Any]], MaybeAwaitable[ModelInputData]] + + +@dataclass +class RunConfig: + """Configures settings for the entire agent run.""" + + model: str | Model | None = None + """The model to use for the entire agent run. If set, will override the model set on every + agent. The model_provider passed in below must be able to resolve this model name. + """ + + model_provider: ModelProvider = field(default_factory=MultiProvider) + """The model provider to use when looking up string model names. Defaults to OpenAI.""" + + model_settings: ModelSettings | None = None + """Configure global model settings. Any non-null values will override the agent-specific model + settings. + """ + + handoff_input_filter: HandoffInputFilter | None = None + """A global input filter to apply to all handoffs. If `Handoff.input_filter` is set, then that + will take precedence. The input filter allows you to edit the inputs that are sent to the new + agent. See the documentation in `Handoff.input_filter` for more details. + """ + + input_guardrails: list[InputGuardrail[Any]] | None = None + """A list of input guardrails to run on the initial run input.""" + + output_guardrails: list[OutputGuardrail[Any]] | None = None + """A list of output guardrails to run on the final output of the run.""" + + tracing_disabled: bool = False + """Whether tracing is disabled for the agent run. If disabled, we will not trace the agent run. + """ + + trace_include_sensitive_data: bool = field( + default_factory=_default_trace_include_sensitive_data + ) + """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or + LLM generations) in traces. If False, we'll still create spans for these events, but the + sensitive data will not be included. + """ + + workflow_name: str = "Agent workflow" + """The name of the run, used for tracing. Should be a logical name for the run, like + "Code generation workflow" or "Customer support agent". + """ + + trace_id: str | None = None + """A custom trace ID to use for tracing. If not provided, we will generate a new trace ID.""" + + group_id: str | None = None + """ + A grouping identifier to use for tracing, to link multiple traces from the same conversation + or process. For example, you might use a chat thread ID. + """ + + trace_metadata: dict[str, Any] | None = None + """ + An optional dictionary of additional metadata to include with the trace. + """ + + session_input_callback: SessionInputCallback | None = None + """Defines how to handle session history when new input is provided. + - `None` (default): The new input is appended to the session history. + - `SessionInputCallback`: A custom function that receives the history and new input, and + returns the desired combined list of items. + """ + + call_model_input_filter: CallModelInputFilter | None = None + """ + Optional callback that is invoked immediately before calling the model. It receives the current + agent, context and the model input (instructions and input items), and must return a possibly + modified `ModelInputData` to use for the model call. + + This allows you to edit the input sent to the model e.g. to stay within a token limit. + For example, you can use this to add a system prompt to the input. + """ + + +class RunOptions(TypedDict, Generic[TContext]): + """Arguments for ``AgentRunner`` methods.""" + + context: NotRequired[TContext | None] + """The context for the run.""" + + max_turns: NotRequired[int] + """The maximum number of turns to run for.""" + + hooks: NotRequired[RunHooks[TContext] | None] + """Lifecycle hooks for the run.""" + + run_config: NotRequired[RunConfig | None] + """Run configuration.""" + + previous_response_id: NotRequired[str | None] + """The ID of the previous response, if any.""" + + conversation_id: NotRequired[str | None] + """The ID of the stored conversation, if any.""" + + session: NotRequired[Session | None] + """The session for the run.""" + + +class Runner: + @classmethod + async def run( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResult: + """ + Run a workflow starting at the given agent. + + The agent will run in a loop until a final output is generated. The loop runs like so: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response. If using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The conversation ID + (https://platform.openai.com/docs/guides/conversation-state?api-mode=responses). + If provided, the conversation will be used to read and write items. + Every agent will have access to the conversation history so far, + and its output items will be written to the conversation. + We recommend only using this if you are exclusively using OpenAI models; + other model providers don't write to the Conversation object, + so you'll end up having partial conversations stored. + session: A session for automatic conversation history management. + + Returns: + A run result containing all the inputs, guardrail results and the output of + the last agent. Agents may perform handoffs, so we don't know the specific + type of the output. + """ + + runner = DEFAULT_AGENT_RUNNER + return await runner.run( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + @classmethod + def run_sync( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + *, + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResult: + """ + Run a workflow synchronously, starting at the given agent. + + Note: + This just wraps the `run` method, so it will not work if there's already an + event loop (e.g. inside an async function, or in a Jupyter notebook or async + context like FastAPI). For those cases, use the `run` method instead. + + The agent will run in a loop until a final output is generated. The loop runs: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The ID of the stored conversation, if any. + session: A session for automatic conversation history management. + + Returns: + A run result containing all the inputs, guardrail results and the output of + the last agent. Agents may perform handoffs, so we don't know the specific + type of the output. + """ + + runner = DEFAULT_AGENT_RUNNER + return runner.run_sync( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + @classmethod + def run_streamed( + cls, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + context: TContext | None = None, + max_turns: int = DEFAULT_MAX_TURNS, + hooks: RunHooks[TContext] | None = None, + run_config: RunConfig | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, + ) -> RunResultStreaming: + """ + Run a workflow starting at the given agent in streaming mode. + + The returned result object contains a method you can use to stream semantic + events as they are generated. + + The agent will run in a loop until a final output is generated. The loop runs like so: + + 1. The agent is invoked with the given input. + 2. If there is a final output (i.e. the agent produces something of type + `agent.output_type`), the loop terminates. + 3. If there's a handoff, we run the loop again, with the new agent. + 4. Else, we run tool calls (if any), and re-run the loop. + + In two cases, the agent may raise an exception: + + 1. If the max_turns is exceeded, a MaxTurnsExceeded exception is raised. + 2. If a guardrail tripwire is triggered, a GuardrailTripwireTriggered + exception is raised. + + Note: + Only the first agent's input guardrails are run. + + Args: + starting_agent: The starting agent to run. + input: The initial input to the agent. You can pass a single string for a + user message, or a list of input items. + context: The context to run the agent with. + max_turns: The maximum number of turns to run the agent for. A turn is + defined as one AI invocation (including any tool calls that might occur). + hooks: An object that receives callbacks on various lifecycle events. + run_config: Global settings for the entire agent run. + previous_response_id: The ID of the previous response, if using OpenAI + models via the Responses API, this allows you to skip passing in input + from the previous turn. + conversation_id: The ID of the stored conversation, if any. + session: A session for automatic conversation history management. + + Returns: + A result object that contains data about the run, as well as a method to + stream events. + """ + + runner = DEFAULT_AGENT_RUNNER + return runner.run_streamed( + starting_agent, + input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + + +class AgentRunner: + """ + WARNING: this class is experimental and not part of the public API + It should not be used directly or subclassed. + """ + + async def run( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + if run_config is None: + run_config = RunConfig() + + if conversation_id is not None or previous_response_id is not None: + server_conversation_tracker = _ServerConversationTracker( + conversation_id=conversation_id, previous_response_id=previous_response_id + ) + else: + server_conversation_tracker = None + + # Keep original user input separate from session-prepared input + original_user_input = input + prepared_input = await self._prepare_input_with_session( + input, session, run_config.session_input_callback + ) + + tool_use_tracker = AgentToolUseTracker() + + with TraceCtxManager( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ): + current_turn = 0 + original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) + generated_items: list[RunItem] = [] + model_responses: list[ModelResponse] = [] + + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context, # type: ignore + ) + + input_guardrail_results: list[InputGuardrailResult] = [] + tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] + tool_output_guardrail_results: list[ToolOutputGuardrailResult] = [] + + current_span: Span[AgentSpanData] | None = None + current_agent = starting_agent + should_run_agent_start_hooks = True + + # save only the new user input to the session, not the combined history + await self._save_result_to_session(session, original_user_input, []) + + try: + while True: + all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) + + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name + for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) + ] + if output_schema := AgentRunner._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + current_span.start(mark_as_current=True) + current_span.span_data.tools = [t.name for t in all_tools] + + current_turn += 1 + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), + ) + raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded") + + logger.debug( + f"Running agent {current_agent.name} (turn {current_turn})", + ) + + if current_turn == 1: + input_guardrail_results, turn_result = await asyncio.gather( + self._run_input_guardrails( + starting_agent, + starting_agent.input_guardrails + + (run_config.input_guardrails or []), + _copy_str_or_list(prepared_input), + context_wrapper, + ), + self._run_single_turn( + agent=current_agent, + all_tools=all_tools, + original_input=original_input, + generated_items=generated_items, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, + server_conversation_tracker=server_conversation_tracker, + ), + ) + else: + turn_result = await self._run_single_turn( + agent=current_agent, + all_tools=all_tools, + original_input=original_input, + generated_items=generated_items, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + should_run_agent_start_hooks=should_run_agent_start_hooks, + tool_use_tracker=tool_use_tracker, + server_conversation_tracker=server_conversation_tracker, + ) + should_run_agent_start_hooks = False + + model_responses.append(turn_result.model_response) + original_input = turn_result.original_input + generated_items = turn_result.generated_items + + if server_conversation_tracker is not None: + server_conversation_tracker.track_server_items(turn_result.model_response) + + # Collect tool guardrail results from this turn + tool_input_guardrail_results.extend(turn_result.tool_input_guardrail_results) + tool_output_guardrail_results.extend(turn_result.tool_output_guardrail_results) + + if isinstance(turn_result.next_step, NextStepFinalOutput): + output_guardrail_results = await self._run_output_guardrails( + current_agent.output_guardrails + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=turn_result.next_step.output, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=output_guardrail_results, + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, + context_wrapper=context_wrapper, + ) + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) + + return result + elif isinstance(turn_result.next_step, NextStepHandoff): + current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + elif isinstance(turn_result.next_step, NextStepRunAgain): + if not any( + guardrail_result.output.tripwire_triggered + for guardrail_result in input_guardrail_results + ): + await self._save_result_to_session( + session, [], turn_result.new_step_items + ) + else: + raise AgentsException( + f"Unknown next step type: {type(turn_result.next_step)}" + ) + except AgentsException as exc: + exc.run_data = RunErrorDetails( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + ) + raise + finally: + if current_span: + current_span.finish(reset_current=True) + + def run_sync( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResult: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = kwargs.get("hooks") + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + + # Python 3.14 stopped implicitly wiring up a default event loop + # when synchronous code touches asyncio APIs for the first time. + # Several of our synchronous entry points (for example the Redis/SQLAlchemy session helpers) + # construct asyncio primitives like asyncio.Lock during __init__, + # which binds them to whatever loop happens to be the thread's default at that moment. + # To keep those locks usable we must ensure that run_sync reuses that same default loop + # instead of hopping over to a brand-new asyncio.run() loop. + try: + already_running_loop = asyncio.get_running_loop() + except RuntimeError: + already_running_loop = None + + if already_running_loop is not None: + # This method is only expected to run when no loop is already active. + # (Each thread has its own default loop; concurrent sync runs should happen on + # different threads. In a single thread use the async API to interleave work.) + raise RuntimeError( + "AgentRunner.run_sync() cannot be called when an event loop is already running." + ) + + policy = asyncio.get_event_loop_policy() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + try: + default_loop = policy.get_event_loop() + except RuntimeError: + default_loop = policy.new_event_loop() + policy.set_event_loop(default_loop) + + # We intentionally leave the default loop open even if we had to create one above. Session + # instances and other helpers stash loop-bound primitives between calls and expect to find + # the same default loop every time run_sync is invoked on this thread. + # Schedule the async run on the default loop so that we can manage cancellation explicitly. + task = default_loop.create_task( + self.run( + starting_agent, + input, + session=session, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + ) + ) + + try: + # Drive the coroutine to completion, harvesting the final RunResult. + return default_loop.run_until_complete(task) + except BaseException: + # If the sync caller aborts (KeyboardInterrupt, etc.), make sure the scheduled task + # does not linger on the shared loop by cancelling it and waiting for completion. + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + default_loop.run_until_complete(task) + raise + finally: + if not default_loop.is_closed(): + # The loop stays open for subsequent runs, but we still need to flush any pending + # async generators so their cleanup code executes promptly. + with contextlib.suppress(RuntimeError): + default_loop.run_until_complete(default_loop.shutdown_asyncgens()) + + def run_streamed( + self, + starting_agent: Agent[TContext], + input: str | list[TResponseInputItem], + **kwargs: Unpack[RunOptions[TContext]], + ) -> RunResultStreaming: + context = kwargs.get("context") + max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) + run_config = kwargs.get("run_config") + previous_response_id = kwargs.get("previous_response_id") + conversation_id = kwargs.get("conversation_id") + session = kwargs.get("session") + + if run_config is None: + run_config = RunConfig() + + # If there's already a trace, we don't create a new one. In addition, we can't end the + # trace here, because the actual work is done in `stream_events` and this method ends + # before that. + new_trace = ( + None + if get_current_trace() + else trace( + workflow_name=run_config.workflow_name, + trace_id=run_config.trace_id, + group_id=run_config.group_id, + metadata=run_config.trace_metadata, + disabled=run_config.tracing_disabled, + ) + ) + + output_schema = AgentRunner._get_output_schema(starting_agent) + context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( + context=context # type: ignore + ) + + streamed_result = RunResultStreaming( + input=_copy_str_or_list(input), + new_items=[], + current_agent=starting_agent, + raw_responses=[], + final_output=None, + is_complete=False, + current_turn=0, + max_turns=max_turns, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + _current_agent_output_schema=output_schema, + trace=new_trace, + context_wrapper=context_wrapper, + ) + + # Kick off the actual agent loop in the background and return the streamed result object. + streamed_result._run_impl_task = asyncio.create_task( + self._start_streaming( + starting_input=input, + streamed_result=streamed_result, + starting_agent=starting_agent, + max_turns=max_turns, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, + ) + ) + return streamed_result + + @staticmethod + def _validate_run_hooks( + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, + ) -> RunHooks[Any]: + if hooks is None: + return RunHooks[Any]() + input_hook_type = type(hooks).__name__ + if isinstance(hooks, AgentHooksBase): + raise TypeError( + "Run hooks must be instances of RunHooks. " + f"Received agent-scoped hooks ({input_hook_type}). " + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." + ) + if not isinstance(hooks, RunHooksBase): + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") + return hooks + + @classmethod + async def _maybe_filter_model_input( + cls, + *, + agent: Agent[TContext], + run_config: RunConfig, + context_wrapper: RunContextWrapper[TContext], + input_items: list[TResponseInputItem], + system_instructions: str | None, + ) -> ModelInputData: + """Apply optional call_model_input_filter to modify model input. + + Returns a `ModelInputData` that will be sent to the model. + """ + effective_instructions = system_instructions + effective_input: list[TResponseInputItem] = input_items + + if run_config.call_model_input_filter is None: + return ModelInputData(input=effective_input, instructions=effective_instructions) + + try: + model_input = ModelInputData( + input=effective_input.copy(), + instructions=effective_instructions, + ) + filter_payload: CallModelData[TContext] = CallModelData( + model_data=model_input, + agent=agent, + context=context_wrapper.context, + ) + maybe_updated = run_config.call_model_input_filter(filter_payload) + updated = await maybe_updated if inspect.isawaitable(maybe_updated) else maybe_updated + if not isinstance(updated, ModelInputData): + raise UserError("call_model_input_filter must return a ModelInputData instance") + return updated + except Exception as e: + _error_tracing.attach_error_to_current_span( + SpanError(message="Error in call_model_input_filter", data={"error": str(e)}) + ) + raise + + @classmethod + async def _run_input_guardrails_with_queue( + cls, + agent: Agent[Any], + guardrails: list[InputGuardrail[TContext]], + input: str | list[TResponseInputItem], + context: RunContextWrapper[TContext], + streamed_result: RunResultStreaming, + parent_span: Span[Any], + ): + queue = streamed_result._input_guardrail_queue + + # We'll run the guardrails and push them onto the queue as they complete + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + ) + for guardrail in guardrails + ] + guardrail_results = [] + try: + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + _error_tracing.attach_error_to_span( + parent_span, + SpanError( + message="Guardrail tripwire triggered", + data={ + "guardrail": result.guardrail.get_name(), + "type": "input_guardrail", + }, + ), + ) + queue.put_nowait(result) + guardrail_results.append(result) + except Exception: + for t in guardrail_tasks: + t.cancel() + raise + + streamed_result.input_guardrail_results = guardrail_results + + @classmethod + async def _start_streaming( + cls, + starting_input: str | list[TResponseInputItem], + streamed_result: RunResultStreaming, + starting_agent: Agent[TContext], + max_turns: int, + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + previous_response_id: str | None, + conversation_id: str | None, + session: Session | None, + ): + if streamed_result.trace: + streamed_result.trace.start(mark_as_current=True) + + current_span: Span[AgentSpanData] | None = None + current_agent = starting_agent + current_turn = 0 + should_run_agent_start_hooks = True + tool_use_tracker = AgentToolUseTracker() + + if conversation_id is not None or previous_response_id is not None: + server_conversation_tracker = _ServerConversationTracker( + conversation_id=conversation_id, previous_response_id=previous_response_id + ) + else: + server_conversation_tracker = None + + streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) + + try: + # Prepare input with session if enabled + prepared_input = await AgentRunner._prepare_input_with_session( + starting_input, session, run_config.session_input_callback + ) + + # Update the streamed result with the prepared input + streamed_result.input = prepared_input + + await AgentRunner._save_result_to_session(session, starting_input, []) + + while True: + # Check for soft cancel before starting new turn + if streamed_result._cancel_mode == "after_turn": + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if streamed_result.is_complete: + break + + all_tools = await cls._get_all_tools(current_agent, context_wrapper) + + # Start an agent span if we don't have one. This span is ended if the current + # agent changes, or if the agent loop ends. + if current_span is None: + handoff_names = [ + h.agent_name + for h in await cls._get_handoffs(current_agent, context_wrapper) + ] + if output_schema := cls._get_output_schema(current_agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + current_span = agent_span( + name=current_agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + current_span.start(mark_as_current=True) + tool_names = [t.name for t in all_tools] + current_span.span_data.tools = tool_names + current_turn += 1 + streamed_result.current_turn = current_turn + + if current_turn > max_turns: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Max turns exceeded", + data={"max_turns": max_turns}, + ), + ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + + if current_turn == 1: + # Run the input guardrails in the background and put the results on the queue + streamed_result._input_guardrails_task = asyncio.create_task( + cls._run_input_guardrails_with_queue( + starting_agent, + starting_agent.input_guardrails + (run_config.input_guardrails or []), + ItemHelpers.input_to_new_input_list(prepared_input), + context_wrapper, + streamed_result, + current_span, + ) + ) + try: + turn_result = await cls._run_single_turn_streamed( + streamed_result, + current_agent, + hooks, + context_wrapper, + run_config, + should_run_agent_start_hooks, + tool_use_tracker, + all_tools, + server_conversation_tracker, + ) + should_run_agent_start_hooks = False + + streamed_result.raw_responses = streamed_result.raw_responses + [ + turn_result.model_response + ] + streamed_result.input = turn_result.original_input + streamed_result.new_items = turn_result.generated_items + + if server_conversation_tracker is not None: + server_conversation_tracker.track_server_items(turn_result.model_response) + + if isinstance(turn_result.next_step, NextStepHandoff): + # Save the conversation to session if enabled (before handoff) + # Note: Non-streaming path doesn't save handoff turns immediately, + # but streaming needs to for graceful cancellation support + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + current_agent = turn_result.next_step.new_agent + current_span.finish(reset_current=True) + current_span = None + should_run_agent_start_hooks = True + streamed_result._event_queue.put_nowait( + AgentUpdatedStreamEvent(new_agent=current_agent) + ) + + # Check for soft cancel after handoff + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + elif isinstance(turn_result.next_step, NextStepFinalOutput): + streamed_result._output_guardrails_task = asyncio.create_task( + cls._run_output_guardrails( + current_agent.output_guardrails + + (run_config.output_guardrails or []), + current_agent, + turn_result.next_step.output, + context_wrapper, + ) + ) + + try: + output_guardrail_results = await streamed_result._output_guardrails_task + except Exception: + # Exceptions will be checked in the stream_events loop + output_guardrail_results = [] + + streamed_result.output_guardrail_results = output_guardrail_results + streamed_result.final_output = turn_result.next_step.output + streamed_result.is_complete = True + + # Save the conversation to session if enabled + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + elif isinstance(turn_result.next_step, NextStepRunAgain): + if session is not None: + should_skip_session_save = ( + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + ) + if should_skip_session_save is False: + await AgentRunner._save_result_to_session( + session, [], turn_result.new_step_items + ) + + # Check for soft cancel after turn completion + if streamed_result._cancel_mode == "after_turn": # type: ignore[comparison-overlap] + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + except AgentsException as exc: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + exc.run_data = RunErrorDetails( + input=streamed_result.input, + new_items=streamed_result.new_items, + raw_responses=streamed_result.raw_responses, + last_agent=current_agent, + context_wrapper=context_wrapper, + input_guardrail_results=streamed_result.input_guardrail_results, + output_guardrail_results=streamed_result.output_guardrail_results, + ) + raise + except Exception as e: + if current_span: + _error_tracing.attach_error_to_span( + current_span, + SpanError( + message="Error in agent run", + data={"error": str(e)}, + ), + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + + streamed_result.is_complete = True + finally: + if streamed_result._input_guardrails_task: + try: + await AgentRunner._input_guardrail_tripwire_triggered_for_stream( + streamed_result + ) + except Exception as e: + logger.debug( + f"Error in streamed_result finalize for agent {current_agent.name} - {e}" + ) + if current_span: + current_span.finish(reset_current=True) + if streamed_result.trace: + streamed_result.trace.finish(reset_current=True) + + @classmethod + async def _run_single_turn_streamed( + cls, + streamed_result: RunResultStreaming, + agent: Agent[TContext], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, + all_tools: list[Tool], + server_conversation_tracker: _ServerConversationTracker | None = None, + ) -> SingleStepResult: + emitted_tool_call_ids: set[str] = set() + emitted_reasoning_item_ids: set[str] = set() + + if should_run_agent_start_hooks: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + output_schema = cls._get_output_schema(agent) + + streamed_result.current_agent = agent + streamed_result._current_agent_output_schema = output_schema + + system_prompt, prompt_config = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_prompt(context_wrapper), + ) + + handoffs = await cls._get_handoffs(agent, context_wrapper) + model = cls._get_model(agent, run_config) + model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + + final_response: ModelResponse | None = None + + if server_conversation_tracker is not None: + input = server_conversation_tracker.prepare_input( + streamed_result.input, streamed_result.new_items + ) + else: + input = ItemHelpers.input_to_new_input_list(streamed_result.input) + input.extend([item.to_input_item() for item in streamed_result.new_items]) + + # THIS IS THE RESOLVED CONFLICT BLOCK + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + + # Call hook just before the model is invoked, with the correct system_prompt. + await asyncio.gather( + hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + ( + agent.hooks.on_llm_start( + context_wrapper, agent, filtered.instructions, filtered.input + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) + + # 1. Stream the output events + async for event in model.stream_response( + filtered.instructions, + filtered.input, + model_settings, + all_tools, + output_schema, + handoffs, + get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, + ): + # Emit the raw event ASAP + streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) + + if isinstance(event, ResponseCompletedEvent): + usage = ( + Usage( + requests=1, + input_tokens=event.response.usage.input_tokens, + output_tokens=event.response.usage.output_tokens, + total_tokens=event.response.usage.total_tokens, + input_tokens_details=event.response.usage.input_tokens_details, + output_tokens_details=event.response.usage.output_tokens_details, + ) + if event.response.usage + else Usage() + ) + final_response = ModelResponse( + output=event.response.output, + usage=usage, + response_id=event.response.id, + ) + context_wrapper.usage.add(usage) + + if isinstance(event, ResponseOutputItemDoneEvent): + output_item = event.item + + if isinstance(output_item, _TOOL_CALL_TYPES): + call_id: str | None = getattr( + output_item, "call_id", getattr(output_item, "id", None) + ) + + if call_id and call_id not in emitted_tool_call_ids: + emitted_tool_call_ids.add(call_id) + + tool_item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, output_item), + agent=agent, + ) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=tool_item, name="tool_called") + ) + + elif isinstance(output_item, ResponseReasoningItem): + reasoning_id: str | None = getattr(output_item, "id", None) + + if reasoning_id and reasoning_id not in emitted_reasoning_item_ids: + emitted_reasoning_item_ids.add(reasoning_id) + + reasoning_item = ReasoningItem(raw_item=output_item, agent=agent) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created") + ) + + # Call hook just after the model response is finalized. + if final_response is not None: + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, final_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, final_response), + ) + + # 2. At this point, the streaming is complete for this turn of the agent loop. + if not final_response: + raise ModelBehaviorError("Model did not produce a final response!") + + # 3. Now, we can process the turn as we do in the non-streaming case + single_step_result = await cls._get_single_step_result_from_response( + agent=agent, + original_input=streamed_result.input, + pre_step_items=streamed_result.new_items, + new_response=final_response, + output_schema=output_schema, + all_tools=all_tools, + handoffs=handoffs, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + tool_use_tracker=tool_use_tracker, + event_queue=streamed_result._event_queue, + ) + + import dataclasses as _dc + + # Filter out items that have already been sent to avoid duplicates + items_to_filter = single_step_result.new_step_items + + if emitted_tool_call_ids: + # Filter out tool call items that were already emitted during streaming + items_to_filter = [ + item + for item in items_to_filter + if not ( + isinstance(item, ToolCallItem) + and ( + call_id := getattr( + item.raw_item, "call_id", getattr(item.raw_item, "id", None) + ) + ) + and call_id in emitted_tool_call_ids + ) + ] + + if emitted_reasoning_item_ids: + # Filter out reasoning items that were already emitted during streaming + items_to_filter = [ + item + for item in items_to_filter + if not ( + isinstance(item, ReasoningItem) + and (reasoning_id := getattr(item.raw_item, "id", None)) + and reasoning_id in emitted_reasoning_item_ids + ) + ] + + # Filter out HandoffCallItem to avoid duplicates (already sent earlier) + items_to_filter = [ + item for item in items_to_filter if not isinstance(item, HandoffCallItem) + ] + + # Create filtered result and send to queue + filtered_result = _dc.replace(single_step_result, new_step_items=items_to_filter) + RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) + return single_step_result + + @classmethod + async def _run_single_turn( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + original_input: str | list[TResponseInputItem], + generated_items: list[RunItem], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + should_run_agent_start_hooks: bool, + tool_use_tracker: AgentToolUseTracker, + server_conversation_tracker: _ServerConversationTracker | None = None, + ) -> SingleStepResult: + # Ensure we run the hooks before anything else + if should_run_agent_start_hooks: + await asyncio.gather( + hooks.on_agent_start(context_wrapper, agent), + ( + agent.hooks.on_start(context_wrapper, agent) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + system_prompt, prompt_config = await asyncio.gather( + agent.get_system_prompt(context_wrapper), + agent.get_prompt(context_wrapper), + ) + + output_schema = cls._get_output_schema(agent) + handoffs = await cls._get_handoffs(agent, context_wrapper) + if server_conversation_tracker is not None: + input = server_conversation_tracker.prepare_input(original_input, generated_items) + else: + input = ItemHelpers.input_to_new_input_list(original_input) + input.extend([generated_item.to_input_item() for generated_item in generated_items]) + + new_response = await cls._get_new_response( + agent, + system_prompt, + input, + output_schema, + all_tools, + handoffs, + hooks, + context_wrapper, + run_config, + tool_use_tracker, + server_conversation_tracker, + prompt_config, + ) + + return await cls._get_single_step_result_from_response( + agent=agent, + original_input=original_input, + pre_step_items=generated_items, + new_response=new_response, + output_schema=output_schema, + all_tools=all_tools, + handoffs=handoffs, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + tool_use_tracker=tool_use_tracker, + ) + + @classmethod + async def _get_single_step_result_from_response( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + original_input: str | list[TResponseInputItem], + pre_step_items: list[RunItem], + new_response: ModelResponse, + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, + ) -> SingleStepResult: + processed_response = RunImpl.process_model_response( + agent=agent, + all_tools=all_tools, + response=new_response, + output_schema=output_schema, + handoffs=handoffs, + ) + + tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + + # Send handoff items immediately for streaming, but avoid duplicates + if event_queue is not None and processed_response.new_items: + handoff_items = [ + item for item in processed_response.new_items if isinstance(item, HandoffCallItem) + ] + if handoff_items: + RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) + + return await RunImpl.execute_tools_and_side_effects( + agent=agent, + original_input=original_input, + pre_step_items=pre_step_items, + new_response=new_response, + processed_response=processed_response, + output_schema=output_schema, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + ) + + @classmethod + async def _get_single_step_result_from_streamed_response( + cls, + *, + agent: Agent[TContext], + all_tools: list[Tool], + streamed_result: RunResultStreaming, + new_response: ModelResponse, + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + ) -> SingleStepResult: + original_input = streamed_result.input + pre_step_items = streamed_result.new_items + event_queue = streamed_result._event_queue + + processed_response = RunImpl.process_model_response( + agent=agent, + all_tools=all_tools, + response=new_response, + output_schema=output_schema, + handoffs=handoffs, + ) + new_items_processed_response = processed_response.new_items + tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue) + + single_step_result = await RunImpl.execute_tools_and_side_effects( + agent=agent, + original_input=original_input, + pre_step_items=pre_step_items, + new_response=new_response, + processed_response=processed_response, + output_schema=output_schema, + hooks=hooks, + context_wrapper=context_wrapper, + run_config=run_config, + ) + new_step_items = [ + item + for item in single_step_result.new_step_items + if item not in new_items_processed_response + ] + RunImpl.stream_step_items_to_queue(new_step_items, event_queue) + + return single_step_result + + @classmethod + async def _run_input_guardrails( + cls, + agent: Agent[Any], + guardrails: list[InputGuardrail[TContext]], + input: str | list[TResponseInputItem], + context: RunContextWrapper[TContext], + ) -> list[InputGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_input_guardrail(agent, guardrail, input, context) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise InputGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + + @classmethod + async def _run_output_guardrails( + cls, + guardrails: list[OutputGuardrail[TContext]], + agent: Agent[TContext], + agent_output: Any, + context: RunContextWrapper[TContext], + ) -> list[OutputGuardrailResult]: + if not guardrails: + return [] + + guardrail_tasks = [ + asyncio.create_task( + RunImpl.run_single_output_guardrail(guardrail, agent, agent_output, context) + ) + for guardrail in guardrails + ] + + guardrail_results = [] + + for done in asyncio.as_completed(guardrail_tasks): + result = await done + if result.output.tripwire_triggered: + # Cancel all guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + _error_tracing.attach_error_to_current_span( + SpanError( + message="Guardrail tripwire triggered", + data={"guardrail": result.guardrail.get_name()}, + ) + ) + raise OutputGuardrailTripwireTriggered(result) + else: + guardrail_results.append(result) + + return guardrail_results + + @classmethod + async def _get_new_response( + cls, + agent: Agent[TContext], + system_prompt: str | None, + input: list[TResponseInputItem], + output_schema: AgentOutputSchemaBase | None, + all_tools: list[Tool], + handoffs: list[Handoff], + hooks: RunHooks[TContext], + context_wrapper: RunContextWrapper[TContext], + run_config: RunConfig, + tool_use_tracker: AgentToolUseTracker, + server_conversation_tracker: _ServerConversationTracker | None, + prompt_config: ResponsePromptParam | None, + ) -> ModelResponse: + # Allow user to modify model input right before the call, if configured + filtered = await cls._maybe_filter_model_input( + agent=agent, + run_config=run_config, + context_wrapper=context_wrapper, + input_items=input, + system_instructions=system_prompt, + ) + + model = cls._get_model(agent, run_config) + model_settings = agent.model_settings.resolve(run_config.model_settings) + model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + + # If we have run hooks, or if the agent has hooks, we need to call them before the LLM call + await asyncio.gather( + hooks.on_llm_start(context_wrapper, agent, filtered.instructions, filtered.input), + ( + agent.hooks.on_llm_start( + context_wrapper, + agent, + filtered.instructions, # Use filtered instructions + filtered.input, # Use filtered input + ) + if agent.hooks + else _coro.noop_coroutine() + ), + ) + + previous_response_id = ( + server_conversation_tracker.previous_response_id + if server_conversation_tracker + else None + ) + conversation_id = ( + server_conversation_tracker.conversation_id if server_conversation_tracker else None + ) + + new_response = await model.get_response( + system_instructions=filtered.instructions, + input=filtered.input, + model_settings=model_settings, + tools=all_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, + ) + + context_wrapper.usage.add(new_response.usage) + + # If we have run hooks, or if the agent has hooks, we need to call them after the LLM call + await asyncio.gather( + ( + agent.hooks.on_llm_end(context_wrapper, agent, new_response) + if agent.hooks + else _coro.noop_coroutine() + ), + hooks.on_llm_end(context_wrapper, agent, new_response), + ) + + return new_response + + @classmethod + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: + if agent.output_type is None or agent.output_type is str: + return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type + + return AgentOutputSchema(agent.output_type) + + @classmethod + async def _get_handoffs( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff]: + handoffs = [] + for handoff_item in agent.handoffs: + if isinstance(handoff_item, Handoff): + handoffs.append(handoff_item) + elif isinstance(handoff_item, Agent): + handoffs.append(handoff(handoff_item)) + + async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return bool(await res) + return bool(res) + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] + return enabled + + @classmethod + async def _get_all_tools( + cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Tool]: + return await agent.get_all_tools(context_wrapper) + + @classmethod + def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: + if isinstance(run_config.model, Model): + return run_config.model + elif isinstance(run_config.model, str): + return run_config.model_provider.get_model(run_config.model) + elif isinstance(agent.model, Model): + return agent.model + + return run_config.model_provider.get_model(agent.model) + + @classmethod + async def _prepare_input_with_session( + cls, + input: str | list[TResponseInputItem], + session: Session | None, + session_input_callback: SessionInputCallback | None, + ) -> str | list[TResponseInputItem]: + """Prepare input by combining it with session history if enabled.""" + if session is None: + return input + + # If the user doesn't specify an input callback and pass a list as input + if isinstance(input, list) and not session_input_callback: + raise UserError( + "When using session memory, list inputs require a " + "`RunConfig.session_input_callback` to define how they should be merged " + "with the conversation history. If you don't want to use a callback, " + "provide your input as a string instead, or disable session memory " + "(session=None) and pass a list to manage the history manually." + ) + + # Get previous conversation history + history = await session.get_items() + + # Convert input to list format + new_input_list = ItemHelpers.input_to_new_input_list(input) + + if session_input_callback is None: + return history + new_input_list + elif callable(session_input_callback): + res = session_input_callback(history, new_input_list) + if inspect.isawaitable(res): + return await res + return res + else: + raise UserError( + f"Invalid `session_input_callback` value: {session_input_callback}. " + "Choose between `None` or a custom callable function." + ) + + @classmethod + async def _save_result_to_session( + cls, + session: Session | None, + original_input: str | list[TResponseInputItem], + new_items: list[RunItem], + ) -> None: + """ + Save the conversation turn to session. + It does not account for any filtering or modification performed by + `RunConfig.session_input_callback`. + """ + if session is None: + return + + # Convert original input to list format if needed + input_list = ItemHelpers.input_to_new_input_list(original_input) + + # Convert new items to input format + new_items_as_input = [item.to_input_item() for item in new_items] + + # Save all items from this turn + items_to_save = input_list + new_items_as_input + await session.add_items(items_to_save) + + @staticmethod + async def _input_guardrail_tripwire_triggered_for_stream( + streamed_result: RunResultStreaming, + ) -> bool: + """Return True if any input guardrail triggered during a streamed run.""" + + task = streamed_result._input_guardrails_task + if task is None: + return False + + if not task.done(): + await task + + return any( + guardrail_result.output.tripwire_triggered + for guardrail_result in streamed_result.input_guardrail_results + ) + + +DEFAULT_AGENT_RUNNER = AgentRunner() +_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) + + +def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]: + if isinstance(input, str): + return input + return input.copy() diff --git a/tests/fake_model.py b/tests/fake_model.py index efedeb7fe..6e13a02a4 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -1,343 +1,343 @@ -from __future__ import annotations - -from collections.abc import AsyncIterator -from typing import Any - -from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, - ResponseFunctionCallArgumentsDeltaEvent, - ResponseFunctionCallArgumentsDoneEvent, - ResponseFunctionToolCall, - ResponseInProgressEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, - ResponseReasoningSummaryPartAddedEvent, - ResponseReasoningSummaryPartDoneEvent, - ResponseReasoningSummaryTextDeltaEvent, - ResponseReasoningSummaryTextDoneEvent, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, - ResponseUsage, -) -from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from openai.types.responses.response_reasoning_summary_part_added_event import ( - Part as AddedEventPart, -) -from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart -from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails - -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import ( - ModelResponse, - TResponseInputItem, - TResponseOutputItem, - TResponseStreamEvent, -) -from agents.model_settings import ModelSettings -from agents.models.interface import Model, ModelTracing -from agents.tool import Tool -from agents.tracing import SpanError, generation_span -from agents.usage import Usage - - -class FakeModel(Model): - def __init__( - self, - tracing_enabled: bool = False, - initial_output: list[TResponseOutputItem] | Exception | None = None, - ): - if initial_output is None: - initial_output = [] - self.turn_outputs: list[list[TResponseOutputItem] | Exception] = ( - [initial_output] if initial_output else [] - ) - self.tracing_enabled = tracing_enabled - self.last_turn_args: dict[str, Any] = {} - self.first_turn_args: dict[str, Any] | None = None - self.hardcoded_usage: Usage | None = None - - def set_hardcoded_usage(self, usage: Usage): - self.hardcoded_usage = usage - - def set_next_output(self, output: list[TResponseOutputItem] | Exception): - self.turn_outputs.append(output) - - def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]): - self.turn_outputs.extend(outputs) - - def get_next_output(self) -> list[TResponseOutputItem] | Exception: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> ModelResponse: - turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - if self.first_turn_args is None: - self.first_turn_args = turn_args.copy() - - self.last_turn_args = turn_args - - with generation_span(disabled=not self.tracing_enabled) as span: - output = self.get_next_output() - - if isinstance(output, Exception): - span.set_error( - SpanError( - message="Error", - data={ - "name": output.__class__.__name__, - "message": str(output), - }, - ) - ) - raise output - - return ModelResponse( - output=output, - usage=self.hardcoded_usage or Usage(), - response_id="resp-789", - ) - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None = None, - conversation_id: str | None = None, - prompt: Any | None = None, - ) -> AsyncIterator[TResponseStreamEvent]: - turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - if self.first_turn_args is None: - self.first_turn_args = turn_args.copy() - - self.last_turn_args = turn_args - with generation_span(disabled=not self.tracing_enabled) as span: - output = self.get_next_output() - if isinstance(output, Exception): - span.set_error( - SpanError( - message="Error", - data={ - "name": output.__class__.__name__, - "message": str(output), - }, - ) - ) - raise output - - response = get_response_obj(output, usage=self.hardcoded_usage) - sequence_number = 0 - - yield ResponseCreatedEvent( - type="response.created", - response=response, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseInProgressEvent( - type="response.in_progress", - response=response, - sequence_number=sequence_number, - ) - sequence_number += 1 - - for output_index, output_item in enumerate(output): - yield ResponseOutputItemAddedEvent( - type="response.output_item.added", - item=output_item, - output_index=output_index, - sequence_number=sequence_number, - ) - sequence_number += 1 - - if isinstance(output_item, ResponseReasoningItem): - if output_item.summary: - for summary_index, summary in enumerate(output_item.summary): - yield ResponseReasoningSummaryPartAddedEvent( - type="response.reasoning_summary_part.added", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - part=AddedEventPart(text=summary.text, type=summary.type), - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryTextDeltaEvent( - type="response.reasoning_summary_text.delta", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - delta=summary.text, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryTextDoneEvent( - type="response.reasoning_summary_text.done", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - text=summary.text, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseReasoningSummaryPartDoneEvent( - type="response.reasoning_summary_part.done", - item_id=output_item.id, - output_index=output_index, - summary_index=summary_index, - part=DoneEventPart(text=summary.text, type=summary.type), - sequence_number=sequence_number, - ) - sequence_number += 1 - - elif isinstance(output_item, ResponseFunctionToolCall): - yield ResponseFunctionCallArgumentsDeltaEvent( - type="response.function_call_arguments.delta", - item_id=output_item.call_id, - output_index=output_index, - delta=output_item.arguments, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseFunctionCallArgumentsDoneEvent( - type="response.function_call_arguments.done", - item_id=output_item.call_id, - output_index=output_index, - arguments=output_item.arguments, - name=output_item.name, - sequence_number=sequence_number, - ) - sequence_number += 1 - - elif isinstance(output_item, ResponseOutputMessage): - for content_index, content_part in enumerate(output_item.content): - if isinstance(content_part, ResponseOutputText): - yield ResponseContentPartAddedEvent( - type="response.content_part.added", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - part=content_part, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseTextDeltaEvent( - type="response.output_text.delta", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - delta=content_part.text, - logprobs=[], - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseTextDoneEvent( - type="response.output_text.done", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - text=content_part.text, - logprobs=[], - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseContentPartDoneEvent( - type="response.content_part.done", - item_id=output_item.id, - output_index=output_index, - content_index=content_index, - part=content_part, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseOutputItemDoneEvent( - type="response.output_item.done", - item=output_item, - output_index=output_index, - sequence_number=sequence_number, - ) - sequence_number += 1 - - yield ResponseCompletedEvent( - type="response.completed", - response=response, - sequence_number=sequence_number, - ) - - -def get_response_obj( - output: list[TResponseOutputItem], - response_id: str | None = None, - usage: Usage | None = None, -) -> Response: - return Response( - id=response_id or "resp-789", - created_at=123, - model="test_model", - object="response", - output=output, - tool_choice="none", - tools=[], - top_p=None, - parallel_tool_calls=False, - usage=ResponseUsage( - input_tokens=usage.input_tokens if usage else 0, - output_tokens=usage.output_tokens if usage else 0, - total_tokens=usage.total_tokens if usage else 0, - input_tokens_details=InputTokensDetails(cached_tokens=0), - output_tokens_details=OutputTokensDetails(reasoning_tokens=0), - ), - ) +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningSummaryPartAddedEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseReasoningSummaryTextDeltaEvent, + ResponseReasoningSummaryTextDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ResponseUsage, +) +from openai.types.responses.response_reasoning_item import ResponseReasoningItem +from openai.types.responses.response_reasoning_summary_part_added_event import ( + Part as AddedEventPart, +) +from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails + +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import ( + ModelResponse, + TResponseInputItem, + TResponseOutputItem, + TResponseStreamEvent, +) +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.tool import Tool +from agents.tracing import SpanError, generation_span +from agents.usage import Usage + + +class FakeModel(Model): + def __init__( + self, + tracing_enabled: bool = False, + initial_output: list[TResponseOutputItem] | Exception | None = None, + ): + if initial_output is None: + initial_output = [] + self.turn_outputs: list[list[TResponseOutputItem] | Exception] = ( + [initial_output] if initial_output else [] + ) + self.tracing_enabled = tracing_enabled + self.last_turn_args: dict[str, Any] = {} + self.first_turn_args: dict[str, Any] | None = None + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage + + def set_next_output(self, output: list[TResponseOutputItem] | Exception): + self.turn_outputs.append(output) + + def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem] | Exception]): + self.turn_outputs.extend(outputs) + + def get_next_output(self) -> list[TResponseOutputItem] | Exception: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> ModelResponse: + turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + if self.first_turn_args is None: + self.first_turn_args = turn_args.copy() + + self.last_turn_args = turn_args + + with generation_span(disabled=not self.tracing_enabled) as span: + output = self.get_next_output() + + if isinstance(output, Exception): + span.set_error( + SpanError( + message="Error", + data={ + "name": output.__class__.__name__, + "message": str(output), + }, + ) + ) + raise output + + return ModelResponse( + output=output, + usage=self.hardcoded_usage or Usage(), + response_id="resp-789", + ) + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None = None, + conversation_id: str | None = None, + prompt: Any | None = None, + ) -> AsyncIterator[TResponseStreamEvent]: + turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + if self.first_turn_args is None: + self.first_turn_args = turn_args.copy() + + self.last_turn_args = turn_args + with generation_span(disabled=not self.tracing_enabled) as span: + output = self.get_next_output() + if isinstance(output, Exception): + span.set_error( + SpanError( + message="Error", + data={ + "name": output.__class__.__name__, + "message": str(output), + }, + ) + ) + raise output + + response = get_response_obj(output, usage=self.hardcoded_usage) + sequence_number = 0 + + yield ResponseCreatedEvent( + type="response.created", + response=response, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseInProgressEvent( + type="response.in_progress", + response=response, + sequence_number=sequence_number, + ) + sequence_number += 1 + + for output_index, output_item in enumerate(output): + yield ResponseOutputItemAddedEvent( + type="response.output_item.added", + item=output_item, + output_index=output_index, + sequence_number=sequence_number, + ) + sequence_number += 1 + + if isinstance(output_item, ResponseReasoningItem): + if output_item.summary: + for summary_index, summary in enumerate(output_item.summary): + yield ResponseReasoningSummaryPartAddedEvent( + type="response.reasoning_summary_part.added", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + part=AddedEventPart(text=summary.text, type=summary.type), + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryTextDeltaEvent( + type="response.reasoning_summary_text.delta", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + delta=summary.text, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryTextDoneEvent( + type="response.reasoning_summary_text.done", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + text=summary.text, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseReasoningSummaryPartDoneEvent( + type="response.reasoning_summary_part.done", + item_id=output_item.id, + output_index=output_index, + summary_index=summary_index, + part=DoneEventPart(text=summary.text, type=summary.type), + sequence_number=sequence_number, + ) + sequence_number += 1 + + elif isinstance(output_item, ResponseFunctionToolCall): + yield ResponseFunctionCallArgumentsDeltaEvent( + type="response.function_call_arguments.delta", + item_id=output_item.call_id, + output_index=output_index, + delta=output_item.arguments, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + item_id=output_item.call_id, + output_index=output_index, + arguments=output_item.arguments, + name=output_item.name, + sequence_number=sequence_number, + ) + sequence_number += 1 + + elif isinstance(output_item, ResponseOutputMessage): + for content_index, content_part in enumerate(output_item.content): + if isinstance(content_part, ResponseOutputText): + yield ResponseContentPartAddedEvent( + type="response.content_part.added", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + part=content_part, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseTextDeltaEvent( + type="response.output_text.delta", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + delta=content_part.text, + logprobs=[], + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseTextDoneEvent( + type="response.output_text.done", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + text=content_part.text, + logprobs=[], + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseContentPartDoneEvent( + type="response.content_part.done", + item_id=output_item.id, + output_index=output_index, + content_index=content_index, + part=content_part, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseOutputItemDoneEvent( + type="response.output_item.done", + item=output_item, + output_index=output_index, + sequence_number=sequence_number, + ) + sequence_number += 1 + + yield ResponseCompletedEvent( + type="response.completed", + response=response, + sequence_number=sequence_number, + ) + + +def get_response_obj( + output: list[TResponseOutputItem], + response_id: str | None = None, + usage: Usage | None = None, +) -> Response: + return Response( + id=response_id or "resp-789", + created_at=123, + model="test_model", + object="response", + output=output, + tool_choice="none", + tools=[], + top_p=None, + parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), + ) diff --git a/tests/test_agent_prompt.py b/tests/test_agent_prompt.py index b11c78893..3d5ed5a3f 100644 --- a/tests/test_agent_prompt.py +++ b/tests/test_agent_prompt.py @@ -1,99 +1,99 @@ -import pytest - -from agents import Agent, Prompt, RunContextWrapper, Runner - -from .fake_model import FakeModel -from .test_responses import get_text_message - - -class PromptCaptureFakeModel(FakeModel): - """Subclass of FakeModel that records the prompt passed to the model.""" - - def __init__(self): - super().__init__() - self.last_prompt = None - - async def get_response( - self, - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - tracing, - *, - previous_response_id, - conversation_id, - prompt, - ): - # Record the prompt that the agent resolved and passed in. - self.last_prompt = prompt - return await super().get_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - tracing, - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt, - ) - - -@pytest.mark.asyncio -async def test_static_prompt_is_resolved_correctly(): - static_prompt: Prompt = { - "id": "my_prompt", - "version": "1", - "variables": {"some_var": "some_value"}, - } - - agent = Agent(name="test", prompt=static_prompt) - context_wrapper = RunContextWrapper(context=None) - - resolved = await agent.get_prompt(context_wrapper) - - assert resolved == { - "id": "my_prompt", - "version": "1", - "variables": {"some_var": "some_value"}, - } - - -@pytest.mark.asyncio -async def test_dynamic_prompt_is_resolved_correctly(): - dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"} - - def dynamic_prompt_fn(_data): - return dynamic_prompt_value - - agent = Agent(name="test", prompt=dynamic_prompt_fn) - context_wrapper = RunContextWrapper(context=None) - - resolved = await agent.get_prompt(context_wrapper) - - assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None} - - -@pytest.mark.asyncio -async def test_prompt_is_passed_to_model(): - static_prompt: Prompt = {"id": "model_prompt"} - - model = PromptCaptureFakeModel() - agent = Agent(name="test", model=model, prompt=static_prompt) - - # Ensure the model returns a simple message so the run completes in one turn. - model.set_next_output([get_text_message("done")]) - - await Runner.run(agent, input="hello") - - # The model should have received the prompt resolved by the agent. - expected_prompt = { - "id": "model_prompt", - "version": None, - "variables": None, - } - assert model.last_prompt == expected_prompt +import pytest + +from agents import Agent, Prompt, RunContextWrapper, Runner + +from .fake_model import FakeModel +from .test_responses import get_text_message + + +class PromptCaptureFakeModel(FakeModel): + """Subclass of FakeModel that records the prompt passed to the model.""" + + def __init__(self): + super().__init__() + self.last_prompt = None + + async def get_response( + self, + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + *, + previous_response_id, + conversation_id, + prompt, + ): + # Record the prompt that the agent resolved and passed in. + self.last_prompt = prompt + return await super().get_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt, + ) + + +@pytest.mark.asyncio +async def test_static_prompt_is_resolved_correctly(): + static_prompt: Prompt = { + "id": "my_prompt", + "version": "1", + "variables": {"some_var": "some_value"}, + } + + agent = Agent(name="test", prompt=static_prompt) + context_wrapper = RunContextWrapper(context=None) + + resolved = await agent.get_prompt(context_wrapper) + + assert resolved == { + "id": "my_prompt", + "version": "1", + "variables": {"some_var": "some_value"}, + } + + +@pytest.mark.asyncio +async def test_dynamic_prompt_is_resolved_correctly(): + dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"} + + def dynamic_prompt_fn(_data): + return dynamic_prompt_value + + agent = Agent(name="test", prompt=dynamic_prompt_fn) + context_wrapper = RunContextWrapper(context=None) + + resolved = await agent.get_prompt(context_wrapper) + + assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None} + + +@pytest.mark.asyncio +async def test_prompt_is_passed_to_model(): + static_prompt: Prompt = {"id": "model_prompt"} + + model = PromptCaptureFakeModel() + agent = Agent(name="test", model=model, prompt=static_prompt) + + # Ensure the model returns a simple message so the run completes in one turn. + model.set_next_output([get_text_message("done")]) + + await Runner.run(agent, input="hello") + + # The model should have received the prompt resolved by the agent. + expected_prompt = { + "id": "model_prompt", + "version": None, + "variables": None, + } + assert model.last_prompt == expected_prompt diff --git a/tests/test_streaming_tool_call_arguments.py b/tests/test_streaming_tool_call_arguments.py index 8e0f847c4..ce476e59b 100644 --- a/tests/test_streaming_tool_call_arguments.py +++ b/tests/test_streaming_tool_call_arguments.py @@ -1,373 +1,373 @@ -""" -Tests to ensure that tool call arguments are properly populated in streaming events. - -This test specifically guards against the regression where tool_called events -were emitted with empty arguments during streaming (Issue #1629). -""" - -import json -from collections.abc import AsyncIterator -from typing import Any, Optional, Union, cast - -import pytest -from openai.types.responses import ( - ResponseCompletedEvent, - ResponseFunctionToolCall, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, -) - -from agents import Agent, Runner, function_tool -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import TResponseInputItem, TResponseOutputItem, TResponseStreamEvent -from agents.model_settings import ModelSettings -from agents.models.interface import Model, ModelTracing -from agents.stream_events import RunItemStreamEvent -from agents.tool import Tool -from agents.tracing import generation_span - -from .fake_model import get_response_obj -from .test_responses import get_function_tool_call - - -class StreamingFakeModel(Model): - """A fake model that actually emits streaming events to test our streaming fix.""" - - def __init__(self): - self.turn_outputs: list[list[TResponseOutputItem]] = [] - self.last_turn_args: dict[str, Any] = {} - - def set_next_output(self, output: list[TResponseOutputItem]): - self.turn_outputs.append(output) - - def get_next_output(self) -> list[TResponseOutputItem]: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: Optional[str], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Optional[AgentOutputSchemaBase], - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: Optional[str], - conversation_id: Optional[str], - prompt: Optional[Any], - ): - raise NotImplementedError("Use stream_response instead") - - async def stream_response( - self, - system_instructions: Optional[str], - input: Union[str, list[TResponseInputItem]], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: Optional[AgentOutputSchemaBase], - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: Optional[str] = None, - conversation_id: Optional[str] = None, - prompt: Optional[Any] = None, - ) -> AsyncIterator[TResponseStreamEvent]: - """Stream events that simulate real OpenAI streaming behavior for tool calls.""" - self.last_turn_args = { - "system_instructions": system_instructions, - "input": input, - "model_settings": model_settings, - "tools": tools, - "output_schema": output_schema, - "previous_response_id": previous_response_id, - "conversation_id": conversation_id, - } - - with generation_span(disabled=True) as _: - output = self.get_next_output() - - sequence_number = 0 - - # Emit each output item with proper streaming events - for item in output: - if isinstance(item, ResponseFunctionToolCall): - # First: emit ResponseOutputItemAddedEvent with EMPTY arguments - # (this simulates the real streaming behavior that was causing the bug) - empty_args_item = ResponseFunctionToolCall( - id=item.id, - call_id=item.call_id, - type=item.type, - name=item.name, - arguments="", # EMPTY - this is the bug condition! - ) - - yield ResponseOutputItemAddedEvent( - item=empty_args_item, - output_index=0, - type="response.output_item.added", - sequence_number=sequence_number, - ) - sequence_number += 1 - - # Then: emit ResponseOutputItemDoneEvent with COMPLETE arguments - yield ResponseOutputItemDoneEvent( - item=item, # This has the complete arguments - output_index=0, - type="response.output_item.done", - sequence_number=sequence_number, - ) - sequence_number += 1 - - # Finally: emit completion - yield ResponseCompletedEvent( - type="response.completed", - response=get_response_obj(output), - sequence_number=sequence_number, - ) - - -@function_tool -def calculate_sum(a: int, b: int) -> str: - """Add two numbers together.""" - return str(a + b) - - -@function_tool -def format_message(name: str, message: str, urgent: bool = False) -> str: - """Format a message with name and urgency.""" - prefix = "URGENT: " if urgent else "" - return f"{prefix}Hello {name}, {message}" - - -@pytest.mark.asyncio -async def test_streaming_tool_call_arguments_not_empty(): - """Test that tool_called events contain non-empty arguments during streaming.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[calculate_sum], - ) - - # Set up a tool call with arguments - expected_arguments = '{"a": 5, "b": 3}' - model.set_next_output( - [ - get_function_tool_call("calculate_sum", expected_arguments, "call_123"), - ] - ) - - result = Runner.run_streamed(agent, input="Add 5 and 3") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - # Verify we got exactly one tool_called event - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - - # Verify the event has the expected structure - assert hasattr(tool_event.item, "raw_item"), "tool_called event should have raw_item" - assert hasattr(tool_event.item.raw_item, "arguments"), "raw_item should have arguments field" - - # The critical test: arguments should NOT be empty - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - assert actual_arguments != "", ( - f"Tool call arguments should not be empty, got: '{actual_arguments}'" - ) - assert actual_arguments is not None, "Tool call arguments should not be None" - - # Verify arguments contain the expected data - assert actual_arguments == expected_arguments, ( - f"Expected arguments '{expected_arguments}', got '{actual_arguments}'" - ) - - # Verify arguments are valid JSON that can be parsed - try: - parsed_args = json.loads(actual_arguments) - assert parsed_args == {"a": 5, "b": 3}, ( - f"Parsed arguments should match expected values, got {parsed_args}" - ) - except json.JSONDecodeError as e: - pytest.fail( - f"Tool call arguments should be valid JSON, but got: '{actual_arguments}' with error: {e}" # noqa: E501 - ) - - -@pytest.mark.asyncio -async def test_streaming_tool_call_arguments_complex(): - """Test streaming tool calls with complex arguments including strings and booleans.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[format_message], - ) - - # Set up a tool call with complex arguments - expected_arguments = ( - '{"name": "Alice", "message": "Your meeting is starting soon", "urgent": true}' - ) - model.set_next_output( - [ - get_function_tool_call("format_message", expected_arguments, "call_456"), - ] - ) - - result = Runner.run_streamed(agent, input="Format a message for Alice") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - - # Critical checks for the regression - assert actual_arguments != "", "Tool call arguments should not be empty" - assert actual_arguments is not None, "Tool call arguments should not be None" - assert actual_arguments == expected_arguments, ( - f"Expected '{expected_arguments}', got '{actual_arguments}'" - ) - - # Verify the complex arguments parse correctly - parsed_args = json.loads(actual_arguments) - expected_parsed = {"name": "Alice", "message": "Your meeting is starting soon", "urgent": True} - assert parsed_args == expected_parsed, f"Parsed arguments should match, got {parsed_args}" - - -@pytest.mark.asyncio -async def test_streaming_multiple_tool_calls_arguments(): - """Test that multiple tool calls in streaming all have proper arguments.""" - model = StreamingFakeModel() - agent = Agent( - name="TestAgent", - model=model, - tools=[calculate_sum, format_message], - ) - - # Set up multiple tool calls - model.set_next_output( - [ - get_function_tool_call("calculate_sum", '{"a": 10, "b": 20}', "call_1"), - get_function_tool_call( - "format_message", '{"name": "Bob", "message": "Test"}', "call_2" - ), - ] - ) - - result = Runner.run_streamed(agent, input="Do some calculations") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - # Should have exactly 2 tool_called events - assert len(tool_called_events) == 2, ( - f"Expected 2 tool_called events, got {len(tool_called_events)}" - ) - - # Check first tool call - event1 = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item1 = cast(ResponseFunctionToolCall, event1.item.raw_item) - args1 = raw_item1.arguments - assert args1 != "", "First tool call arguments should not be empty" - expected_args1 = '{"a": 10, "b": 20}' - assert args1 == expected_args1, ( - f"First tool call args: expected '{expected_args1}', got '{args1}'" - ) - - # Check second tool call - event2 = tool_called_events[1] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item2 = cast(ResponseFunctionToolCall, event2.item.raw_item) - args2 = raw_item2.arguments - assert args2 != "", "Second tool call arguments should not be empty" - expected_args2 = '{"name": "Bob", "message": "Test"}' - assert args2 == expected_args2, ( - f"Second tool call args: expected '{expected_args2}', got '{args2}'" - ) - - -@pytest.mark.asyncio -async def test_streaming_tool_call_with_empty_arguments(): - """Test that tool calls with legitimately empty arguments still work correctly.""" - model = StreamingFakeModel() - - @function_tool - def get_current_time() -> str: - """Get the current time (no arguments needed).""" - return "2024-01-15 10:30:00" - - agent = Agent( - name="TestAgent", - model=model, - tools=[get_current_time], - ) - - # Tool call with empty arguments (legitimate case) - model.set_next_output( - [ - get_function_tool_call("get_current_time", "{}", "call_time"), - ] - ) - - result = Runner.run_streamed(agent, input="What time is it?") - - tool_called_events = [] - async for event in result.stream_events(): - if ( - event.type == "run_item_stream_event" - and isinstance(event, RunItemStreamEvent) - and event.name == "tool_called" - ): - tool_called_events.append(event) - - assert len(tool_called_events) == 1, ( - f"Expected 1 tool_called event, got {len(tool_called_events)}" - ) - - tool_event = tool_called_events[0] - # Cast to ResponseFunctionToolCall since we know that's what it is in our test - raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) - actual_arguments = raw_item.arguments - - # Even "empty" arguments should be "{}", not literally empty string - assert actual_arguments is not None, "Arguments should not be None" - assert actual_arguments == "{}", f"Expected empty JSON object '{{}}', got '{actual_arguments}'" - - # Should parse as valid empty JSON - parsed_args = json.loads(actual_arguments) - assert parsed_args == {}, f"Should parse to empty dict, got {parsed_args}" +""" +Tests to ensure that tool call arguments are properly populated in streaming events. + +This test specifically guards against the regression where tool_called events +were emitted with empty arguments during streaming (Issue #1629). +""" + +import json +from collections.abc import AsyncIterator +from typing import Any, Optional, Union, cast + +import pytest +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseFunctionToolCall, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, +) + +from agents import Agent, Runner, function_tool +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import TResponseInputItem, TResponseOutputItem, TResponseStreamEvent +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.stream_events import RunItemStreamEvent +from agents.tool import Tool +from agents.tracing import generation_span + +from .fake_model import get_response_obj +from .test_responses import get_function_tool_call + + +class StreamingFakeModel(Model): + """A fake model that actually emits streaming events to test our streaming fix.""" + + def __init__(self): + self.turn_outputs: list[list[TResponseOutputItem]] = [] + self.last_turn_args: dict[str, Any] = {} + + def set_next_output(self, output: list[TResponseOutputItem]): + self.turn_outputs.append(output) + + def get_next_output(self) -> list[TResponseOutputItem]: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str], + conversation_id: Optional[str], + prompt: Optional[Any], + ): + raise NotImplementedError("Use stream_response instead") + + async def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt: Optional[Any] = None, + ) -> AsyncIterator[TResponseStreamEvent]: + """Stream events that simulate real OpenAI streaming behavior for tool calls.""" + self.last_turn_args = { + "system_instructions": system_instructions, + "input": input, + "model_settings": model_settings, + "tools": tools, + "output_schema": output_schema, + "previous_response_id": previous_response_id, + "conversation_id": conversation_id, + } + + with generation_span(disabled=True) as _: + output = self.get_next_output() + + sequence_number = 0 + + # Emit each output item with proper streaming events + for item in output: + if isinstance(item, ResponseFunctionToolCall): + # First: emit ResponseOutputItemAddedEvent with EMPTY arguments + # (this simulates the real streaming behavior that was causing the bug) + empty_args_item = ResponseFunctionToolCall( + id=item.id, + call_id=item.call_id, + type=item.type, + name=item.name, + arguments="", # EMPTY - this is the bug condition! + ) + + yield ResponseOutputItemAddedEvent( + item=empty_args_item, + output_index=0, + type="response.output_item.added", + sequence_number=sequence_number, + ) + sequence_number += 1 + + # Then: emit ResponseOutputItemDoneEvent with COMPLETE arguments + yield ResponseOutputItemDoneEvent( + item=item, # This has the complete arguments + output_index=0, + type="response.output_item.done", + sequence_number=sequence_number, + ) + sequence_number += 1 + + # Finally: emit completion + yield ResponseCompletedEvent( + type="response.completed", + response=get_response_obj(output), + sequence_number=sequence_number, + ) + + +@function_tool +def calculate_sum(a: int, b: int) -> str: + """Add two numbers together.""" + return str(a + b) + + +@function_tool +def format_message(name: str, message: str, urgent: bool = False) -> str: + """Format a message with name and urgency.""" + prefix = "URGENT: " if urgent else "" + return f"{prefix}Hello {name}, {message}" + + +@pytest.mark.asyncio +async def test_streaming_tool_call_arguments_not_empty(): + """Test that tool_called events contain non-empty arguments during streaming.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[calculate_sum], + ) + + # Set up a tool call with arguments + expected_arguments = '{"a": 5, "b": 3}' + model.set_next_output( + [ + get_function_tool_call("calculate_sum", expected_arguments, "call_123"), + ] + ) + + result = Runner.run_streamed(agent, input="Add 5 and 3") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + # Verify we got exactly one tool_called event + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + + # Verify the event has the expected structure + assert hasattr(tool_event.item, "raw_item"), "tool_called event should have raw_item" + assert hasattr(tool_event.item.raw_item, "arguments"), "raw_item should have arguments field" + + # The critical test: arguments should NOT be empty + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + assert actual_arguments != "", ( + f"Tool call arguments should not be empty, got: '{actual_arguments}'" + ) + assert actual_arguments is not None, "Tool call arguments should not be None" + + # Verify arguments contain the expected data + assert actual_arguments == expected_arguments, ( + f"Expected arguments '{expected_arguments}', got '{actual_arguments}'" + ) + + # Verify arguments are valid JSON that can be parsed + try: + parsed_args = json.loads(actual_arguments) + assert parsed_args == {"a": 5, "b": 3}, ( + f"Parsed arguments should match expected values, got {parsed_args}" + ) + except json.JSONDecodeError as e: + pytest.fail( + f"Tool call arguments should be valid JSON, but got: '{actual_arguments}' with error: {e}" # noqa: E501 + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_call_arguments_complex(): + """Test streaming tool calls with complex arguments including strings and booleans.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[format_message], + ) + + # Set up a tool call with complex arguments + expected_arguments = ( + '{"name": "Alice", "message": "Your meeting is starting soon", "urgent": true}' + ) + model.set_next_output( + [ + get_function_tool_call("format_message", expected_arguments, "call_456"), + ] + ) + + result = Runner.run_streamed(agent, input="Format a message for Alice") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + + # Critical checks for the regression + assert actual_arguments != "", "Tool call arguments should not be empty" + assert actual_arguments is not None, "Tool call arguments should not be None" + assert actual_arguments == expected_arguments, ( + f"Expected '{expected_arguments}', got '{actual_arguments}'" + ) + + # Verify the complex arguments parse correctly + parsed_args = json.loads(actual_arguments) + expected_parsed = {"name": "Alice", "message": "Your meeting is starting soon", "urgent": True} + assert parsed_args == expected_parsed, f"Parsed arguments should match, got {parsed_args}" + + +@pytest.mark.asyncio +async def test_streaming_multiple_tool_calls_arguments(): + """Test that multiple tool calls in streaming all have proper arguments.""" + model = StreamingFakeModel() + agent = Agent( + name="TestAgent", + model=model, + tools=[calculate_sum, format_message], + ) + + # Set up multiple tool calls + model.set_next_output( + [ + get_function_tool_call("calculate_sum", '{"a": 10, "b": 20}', "call_1"), + get_function_tool_call( + "format_message", '{"name": "Bob", "message": "Test"}', "call_2" + ), + ] + ) + + result = Runner.run_streamed(agent, input="Do some calculations") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + # Should have exactly 2 tool_called events + assert len(tool_called_events) == 2, ( + f"Expected 2 tool_called events, got {len(tool_called_events)}" + ) + + # Check first tool call + event1 = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item1 = cast(ResponseFunctionToolCall, event1.item.raw_item) + args1 = raw_item1.arguments + assert args1 != "", "First tool call arguments should not be empty" + expected_args1 = '{"a": 10, "b": 20}' + assert args1 == expected_args1, ( + f"First tool call args: expected '{expected_args1}', got '{args1}'" + ) + + # Check second tool call + event2 = tool_called_events[1] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item2 = cast(ResponseFunctionToolCall, event2.item.raw_item) + args2 = raw_item2.arguments + assert args2 != "", "Second tool call arguments should not be empty" + expected_args2 = '{"name": "Bob", "message": "Test"}' + assert args2 == expected_args2, ( + f"Second tool call args: expected '{expected_args2}', got '{args2}'" + ) + + +@pytest.mark.asyncio +async def test_streaming_tool_call_with_empty_arguments(): + """Test that tool calls with legitimately empty arguments still work correctly.""" + model = StreamingFakeModel() + + @function_tool + def get_current_time() -> str: + """Get the current time (no arguments needed).""" + return "2024-01-15 10:30:00" + + agent = Agent( + name="TestAgent", + model=model, + tools=[get_current_time], + ) + + # Tool call with empty arguments (legitimate case) + model.set_next_output( + [ + get_function_tool_call("get_current_time", "{}", "call_time"), + ] + ) + + result = Runner.run_streamed(agent, input="What time is it?") + + tool_called_events = [] + async for event in result.stream_events(): + if ( + event.type == "run_item_stream_event" + and isinstance(event, RunItemStreamEvent) + and event.name == "tool_called" + ): + tool_called_events.append(event) + + assert len(tool_called_events) == 1, ( + f"Expected 1 tool_called event, got {len(tool_called_events)}" + ) + + tool_event = tool_called_events[0] + # Cast to ResponseFunctionToolCall since we know that's what it is in our test + raw_item = cast(ResponseFunctionToolCall, tool_event.item.raw_item) + actual_arguments = raw_item.arguments + + # Even "empty" arguments should be "{}", not literally empty string + assert actual_arguments is not None, "Arguments should not be None" + assert actual_arguments == "{}", f"Expected empty JSON object '{{}}', got '{actual_arguments}'" + + # Should parse as valid empty JSON + parsed_args = json.loads(actual_arguments) + assert parsed_args == {}, f"Should parse to empty dict, got {parsed_args}" diff --git a/tests/utils/test_prompts.py b/tests/utils/test_prompts.py index 503dd7dfe..cc277bcfc 100644 --- a/tests/utils/test_prompts.py +++ b/tests/utils/test_prompts.py @@ -105,3 +105,14 @@ def test_should_inject_json_prompt_multiple_tools(): ] result = should_inject_json_prompt(schema, tools, enable_structured_output_with_tools=True) assert result is True + + +def test_should_inject_json_prompt_with_handoffs_as_tools(): + """Test that handoffs (passed as tools) trigger injection when enabled.""" + schema = AgentOutputSchema(SimpleModel) + # Simulate handoffs being passed in the tools list + handoffs_as_tools = [{"type": "function", "name": "handoff_to_agent"}] + result = should_inject_json_prompt( + schema, handoffs_as_tools, enable_structured_output_with_tools=True + ) + assert result is True diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index a12be1dd1..402c52128 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -1,219 +1,219 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncIterator -from typing import Any - -import pytest -from inline_snapshot import snapshot -from openai.types.responses import ResponseCompletedEvent -from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent - -from agents import Agent, Model, ModelSettings, ModelTracing, Tool -from agents.agent_output import AgentOutputSchemaBase -from agents.handoffs import Handoff -from agents.items import ( - ModelResponse, - TResponseInputItem, - TResponseOutputItem, - TResponseStreamEvent, -) - -from ..fake_model import get_response_obj -from ..test_responses import get_function_tool, get_function_tool_call, get_text_message - -try: - from agents.voice import SingleAgentVoiceWorkflow - -except ImportError: - pass - - -class FakeStreamingModel(Model): - def __init__(self): - self.turn_outputs: list[list[TResponseOutputItem]] = [] - - def set_next_output(self, output: list[TResponseOutputItem]): - self.turn_outputs.append(output) - - def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem]]): - self.turn_outputs.extend(outputs) - - def get_next_output(self) -> list[TResponseOutputItem]: - if not self.turn_outputs: - return [] - return self.turn_outputs.pop(0) - - async def get_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> ModelResponse: - raise NotImplementedError("Not implemented") - - async def stream_response( - self, - system_instructions: str | None, - input: str | list[TResponseInputItem], - model_settings: ModelSettings, - tools: list[Tool], - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - tracing: ModelTracing, - *, - previous_response_id: str | None, - conversation_id: str | None, - prompt: Any | None, - ) -> AsyncIterator[TResponseStreamEvent]: - output = self.get_next_output() - for item in output: - if ( - item.type == "message" - and len(item.content) == 1 - and item.content[0].type == "output_text" - ): - yield ResponseTextDeltaEvent( - content_index=0, - delta=item.content[0].text, - type="response.output_text.delta", - output_index=0, - item_id=item.id, - sequence_number=0, - logprobs=[], - ) - - yield ResponseCompletedEvent( - type="response.completed", - response=get_response_obj(output), - sequence_number=1, - ) - - -@pytest.mark.asyncio -async def test_single_agent_workflow(monkeypatch) -> None: - model = FakeStreamingModel() - model.add_multiple_turn_outputs( - [ - # First turn: a message and a tool call - [ - get_function_tool_call("some_function", json.dumps({"a": "b"})), - get_text_message("a_message"), - ], - # Second turn: text message - [get_text_message("done")], - ] - ) - - agent = Agent( - "initial_agent", - model=model, - tools=[get_function_tool("some_function", "tool_result")], - ) - - workflow = SingleAgentVoiceWorkflow(agent) - output = [] - async for chunk in workflow.run("transcription_1"): - output.append(chunk) - - # Validate that the text yielded matches our fake events - assert output == ["a_message", "done"] - # Validate that internal state was updated - assert workflow._input_history == snapshot( - [ - {"content": "transcription_1", "role": "user"}, - { - "arguments": '{"a": "b"}', - "call_id": "2", - "name": "some_function", - "type": "function_call", - "id": "1", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - { - "call_id": "2", - "output": "tool_result", - "type": "function_call_output", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - ] - ) - assert workflow._current_agent == agent - - model.set_next_output([get_text_message("done_2")]) - - # Run it again with a new transcription to make sure the input history is updated - output = [] - async for chunk in workflow.run("transcription_2"): - output.append(chunk) - - assert workflow._input_history == snapshot( - [ - {"role": "user", "content": "transcription_1"}, - { - "arguments": '{"a": "b"}', - "call_id": "2", - "name": "some_function", - "type": "function_call", - "id": "1", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - { - "call_id": "2", - "output": "tool_result", - "type": "function_call_output", - }, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - {"role": "user", "content": "transcription_2"}, - { - "id": "1", - "content": [ - {"annotations": [], "logprobs": [], "text": "done_2", "type": "output_text"} - ], - "role": "assistant", - "status": "completed", - "type": "message", - }, - ] - ) - assert workflow._current_agent == agent +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Any + +import pytest +from inline_snapshot import snapshot +from openai.types.responses import ResponseCompletedEvent +from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent + +from agents import Agent, Model, ModelSettings, ModelTracing, Tool +from agents.agent_output import AgentOutputSchemaBase +from agents.handoffs import Handoff +from agents.items import ( + ModelResponse, + TResponseInputItem, + TResponseOutputItem, + TResponseStreamEvent, +) + +from ..fake_model import get_response_obj +from ..test_responses import get_function_tool, get_function_tool_call, get_text_message + +try: + from agents.voice import SingleAgentVoiceWorkflow + +except ImportError: + pass + + +class FakeStreamingModel(Model): + def __init__(self): + self.turn_outputs: list[list[TResponseOutputItem]] = [] + + def set_next_output(self, output: list[TResponseOutputItem]): + self.turn_outputs.append(output) + + def add_multiple_turn_outputs(self, outputs: list[list[TResponseOutputItem]]): + self.turn_outputs.extend(outputs) + + def get_next_output(self) -> list[TResponseOutputItem]: + if not self.turn_outputs: + return [] + return self.turn_outputs.pop(0) + + async def get_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> ModelResponse: + raise NotImplementedError("Not implemented") + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[TResponseInputItem], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: Any | None, + ) -> AsyncIterator[TResponseStreamEvent]: + output = self.get_next_output() + for item in output: + if ( + item.type == "message" + and len(item.content) == 1 + and item.content[0].type == "output_text" + ): + yield ResponseTextDeltaEvent( + content_index=0, + delta=item.content[0].text, + type="response.output_text.delta", + output_index=0, + item_id=item.id, + sequence_number=0, + logprobs=[], + ) + + yield ResponseCompletedEvent( + type="response.completed", + response=get_response_obj(output), + sequence_number=1, + ) + + +@pytest.mark.asyncio +async def test_single_agent_workflow(monkeypatch) -> None: + model = FakeStreamingModel() + model.add_multiple_turn_outputs( + [ + # First turn: a message and a tool call + [ + get_function_tool_call("some_function", json.dumps({"a": "b"})), + get_text_message("a_message"), + ], + # Second turn: text message + [get_text_message("done")], + ] + ) + + agent = Agent( + "initial_agent", + model=model, + tools=[get_function_tool("some_function", "tool_result")], + ) + + workflow = SingleAgentVoiceWorkflow(agent) + output = [] + async for chunk in workflow.run("transcription_1"): + output.append(chunk) + + # Validate that the text yielded matches our fake events + assert output == ["a_message", "done"] + # Validate that internal state was updated + assert workflow._input_history == snapshot( + [ + {"content": "transcription_1", "role": "user"}, + { + "arguments": '{"a": "b"}', + "call_id": "2", + "name": "some_function", + "type": "function_call", + "id": "1", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + { + "call_id": "2", + "output": "tool_result", + "type": "function_call_output", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + ] + ) + assert workflow._current_agent == agent + + model.set_next_output([get_text_message("done_2")]) + + # Run it again with a new transcription to make sure the input history is updated + output = [] + async for chunk in workflow.run("transcription_2"): + output.append(chunk) + + assert workflow._input_history == snapshot( + [ + {"role": "user", "content": "transcription_1"}, + { + "arguments": '{"a": "b"}', + "call_id": "2", + "name": "some_function", + "type": "function_call", + "id": "1", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "a_message", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + { + "call_id": "2", + "output": "tool_result", + "type": "function_call_output", + }, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + {"role": "user", "content": "transcription_2"}, + { + "id": "1", + "content": [ + {"annotations": [], "logprobs": [], "text": "done_2", "type": "output_text"} + ], + "role": "assistant", + "status": "completed", + "type": "message", + }, + ] + ) + assert workflow._current_agent == agent