Skip to content
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/scripts/translate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def translate_single_source_file(file_path: str) -> None:

def main():
parser = argparse.ArgumentParser(description="Translate documentation files")
parser.add_argument("--file", type=str, help="Specific file to translate (relative to docs directory)")
parser.add_argument(
"--file", type=str, help="Specific file to translate (relative to docs directory)"
)
args = parser.parse_args()

if args.file:
Expand Down
2 changes: 1 addition & 1 deletion examples/basic/hello_world_jupyter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"agent = Agent(name=\"Assistant\", instructions=\"You are a helpful assistant\")\n",
"\n",
"# Intended for Jupyter notebooks where there's an existing event loop\n",
"result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n",
"result = await Runner.run(agent, \"Write a haiku about recursion in programming.\") # type: ignore[top-level-await] # noqa: F704\n",
"print(result.final_output)"
]
}
Expand Down
12 changes: 9 additions & 3 deletions examples/mcp/streamablehttp_example/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from agents.model_settings import ModelSettings


async def run(mcp_server: MCPServer):
async def run(mcp_server: MCPServer, instructions):
agent = Agent(
name="Assistant",
instructions="Use the tools to answer the questions.",
instructions=instructions,
mcp_servers=[mcp_server],
model_settings=ModelSettings(tool_choice="required"),
)
Expand Down Expand Up @@ -46,8 +46,14 @@ async def main():
) as server:
trace_id = gen_trace_id()
with trace(workflow_name="Streamable HTTP Example", trace_id=trace_id):
# List available prompts
prompts = await server.list_prompts()
print(f"Prompts list -> {prompts}")
system_prompt = await server.get_prompt("system_prompt")
instructions = system_prompt.messages[0].content.text
print(f"instructions -> {instructions}")
print(f"View trace: https://platform.openai.com/traces/trace?trace_id={trace_id}\n")
await run(server)
await run(server, instructions)


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions examples/mcp/streamablehttp_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,10 @@ def get_current_weather(city: str) -> str:
return response.text


@mcp.prompt()
def system_prompt():
return "Use the tools to answer the questions."


if __name__ == "__main__":
mcp.run(transport="streamable-http")
4 changes: 2 additions & 2 deletions examples/reasoning_content/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def stream_with_reasoning_content():
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
prompt=None
prompt=None,
):
if event.type == "response.reasoning_summary_text.delta":
print(
Expand Down Expand Up @@ -82,7 +82,7 @@ async def get_response_with_reasoning_content():
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
prompt=None
prompt=None,
)

# Extract reasoning content and regular content from the response
Expand Down
4 changes: 1 addition & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,7 @@ async def get_prompt(
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)

async def get_mcp_tools(
self, run_context: RunContextWrapper[TContext]
) -> list[Tool]:
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(
Expand Down
35 changes: 31 additions & 4 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mcp.client.sse import sse_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.shared.message import SessionMessage
from mcp.types import CallToolResult, InitializeResult
from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
from typing_extensions import NotRequired, TypedDict

from ..exceptions import UserError
Expand Down Expand Up @@ -63,6 +63,18 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
"""Invoke a tool on the server."""
pass

@abc.abstractmethod
async def list_prompts(self) -> ListPromptsResult | None:
"""List the prompts available on the server."""
pass

@abc.abstractmethod
async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
"""Returns an existing prompt from the server."""
pass


class _MCPServerWithClientSession(MCPServer, abc.ABC):
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
Expand Down Expand Up @@ -118,9 +130,7 @@ async def _apply_tool_filter(
return await self._apply_dynamic_tool_filter(tools, run_context, agent)

def _apply_static_tool_filter(
self,
tools: list[MCPTool],
static_filter: ToolFilterStatic
self, tools: list[MCPTool], static_filter: ToolFilterStatic
) -> list[MCPTool]:
"""Apply static tool filtering based on allowlist and blocklist."""
filtered_tools = tools
Expand Down Expand Up @@ -261,6 +271,23 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C

return await self.session.call_tool(tool_name, arguments)

async def list_prompts(
self,
) -> ListPromptsResult | None:
"""List the prompts available on the server."""
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.list_prompts()

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
if not self.session:
raise UserError("Server not initialized. Make sure you call `connect()` first.")

return await self.session.get_prompt(name, arguments)

async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
Expand Down
13 changes: 7 additions & 6 deletions src/agents/model_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
class _OmitTypeAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_none(value: None) -> _Omit:
return _Omit()
Expand All @@ -39,13 +39,14 @@ def validate_from_none(value: None) -> _Omit:
from_none_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: None
),
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
)


Omit = Annotated[_Omit, _OmitTypeAnnotation]
Headers: TypeAlias = Mapping[str, Union[str, Omit]]


@dataclass
class ModelSettings:
"""Settings to use when calling an LLM.
Expand Down
1 change: 0 additions & 1 deletion src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from .util._types import MaybeAwaitable

if TYPE_CHECKING:

from .agent import Agent

ToolParams = ParamSpec("ToolParams")
Expand Down
14 changes: 13 additions & 1 deletion tests/mcp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any

from mcp import Tool as MCPTool
from mcp.types import CallToolResult, TextContent
from mcp.types import CallToolResult, GetPromptResult, ListPromptsResult, TextContent

from agents.mcp import MCPServer
from agents.mcp.server import _MCPServerWithClientSession
Expand Down Expand Up @@ -57,10 +57,12 @@ def name(self) -> str:
class FakeMCPServer(MCPServer):
def __init__(
self,
prompts: ListPromptsResult | None = None,
tools: list[MCPTool] | None = None,
tool_filter: ToolFilter = None,
server_name: str = "fake_mcp_server",
):
self.prompts = prompts
self.tools: list[MCPTool] = tools or []
self.tool_calls: list[str] = []
self.tool_results: list[str] = []
Expand Down Expand Up @@ -94,6 +96,16 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
content=[TextContent(text=self.tool_results[-1], type="text")],
)

async def list_prompts(
self,
) -> ListPromptsResult | None:
return self.prompts

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
) -> GetPromptResult | None:
return None

@property
def name(self) -> str:
return self._server_name
9 changes: 6 additions & 3 deletions tests/mcp/test_tool_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
external dependencies (processes, network connections) and ensure fast, reliable unit tests.
FakeMCPServer delegates filtering logic to the real _MCPServerWithClientSession implementation.
"""

import asyncio

import pytest
Expand All @@ -27,6 +28,7 @@ def create_test_context() -> RunContextWrapper:

# === Static Tool Filtering Tests ===


@pytest.mark.asyncio
async def test_static_tool_filtering():
"""Test all static tool filtering scenarios: allowed, blocked, both, none, etc."""
Expand Down Expand Up @@ -55,7 +57,7 @@ async def test_static_tool_filtering():
# Test both filters together (allowed first, then blocked)
server.tool_filter = {
"allowed_tool_names": ["tool1", "tool2", "tool3"],
"blocked_tool_names": ["tool3"]
"blocked_tool_names": ["tool3"],
}
tools = await server.list_tools(run_context, agent)
assert len(tools) == 2
Expand All @@ -68,8 +70,7 @@ async def test_static_tool_filtering():

# Test helper function
server.tool_filter = create_static_tool_filter(
allowed_tool_names=["tool1", "tool2"],
blocked_tool_names=["tool2"]
allowed_tool_names=["tool1", "tool2"], blocked_tool_names=["tool2"]
)
tools = await server.list_tools(run_context, agent)
assert len(tools) == 1
Expand All @@ -78,6 +79,7 @@ async def test_static_tool_filtering():

# === Dynamic Tool Filtering Core Tests ===


@pytest.mark.asyncio
async def test_dynamic_filter_sync_and_async():
"""Test both synchronous and asynchronous dynamic filters"""
Expand Down Expand Up @@ -181,6 +183,7 @@ def error_prone_filter(context: ToolFilterContext, tool: MCPTool) -> bool:

# === Integration Tests ===


@pytest.mark.asyncio
async def test_agent_dynamic_filtering_integration():
"""Test dynamic filtering integration with Agent methods"""
Expand Down
2 changes: 1 addition & 1 deletion tests/model_settings/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def test_extra_args_resolve_both_none() -> None:
assert resolved.temperature == 0.5
assert resolved.top_p == 0.9

def test_pydantic_serialization() -> None:

def test_pydantic_serialization() -> None:
"""Tests whether ModelSettings can be serialized with Pydantic."""

# First, lets create a ModelSettings instance
Expand Down
14 changes: 5 additions & 9 deletions tests/test_reasoning_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
# Helper functions to create test objects consistently
def create_content_delta(content: str) -> dict[str, Any]:
"""Create a delta dictionary with regular content"""
return {
"content": content,
"role": None,
"function_call": None,
"tool_calls": None
}
return {"content": content, "role": None, "function_call": None, "tool_calls": None}


def create_reasoning_delta(content: str) -> dict[str, Any]:
"""Create a delta dictionary with reasoning content. The Only difference is reasoning_content"""
Expand All @@ -41,7 +37,7 @@ def create_reasoning_delta(content: str) -> dict[str, Any]:
"role": None,
"function_call": None,
"tool_calls": None,
"reasoning_content": content
"reasoning_content": content,
}


Expand Down Expand Up @@ -188,7 +184,7 @@ async def test_get_response_with_reasoning_content(monkeypatch) -> None:
"index": 0,
"finish_reason": "stop",
"message": msg_with_reasoning,
"delta": None
"delta": None,
}

chat = ChatCompletion(
Expand Down Expand Up @@ -274,7 +270,7 @@ async def patched_fetch_response(self, *args, **kwargs):
handoffs=[],
tracing=ModelTracing.DISABLED,
previous_response_id=None,
prompt=None
prompt=None,
):
output_events.append(event)

Expand Down