Skip to content

Commit e5017cd

Browse files
qandrew and Andrew Xia authored
[gpt-oss] disable tool server initialization if no tool in request (vllm-project#25790)
Signed-off-by: Andrew Xia <[email protected]>
Signed-off-by: Andrew Xia <[email protected]>
Co-authored-by: Andrew Xia <[email protected]>
1 parent 6a7796e commit e5017cd

File tree

2 files changed

+148
-12
lines changed

2 files changed

+148
-12
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from contextlib import AsyncExitStack
5+
from unittest.mock import AsyncMock, MagicMock
6+
7+
import pytest
8+
import pytest_asyncio
9+
10+
from vllm.entrypoints.context import ConversationContext
11+
from vllm.entrypoints.openai.protocol import ResponsesRequest
12+
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
13+
from vllm.entrypoints.tool_server import ToolServer
14+
15+
16+
class MockConversationContext(ConversationContext):
    """ConversationContext stand-in that records init_tool_sessions calls.

    Every other method is an inert stub returning an empty/false value, so
    tests can observe whether (and with which positional arguments) the
    serving layer asked for tool sessions to be initialized.
    """

    def __init__(self):
        # Recording slots; init_tool_sessions fills in the first two.
        self.init_tool_sessions_called = False
        self.init_tool_sessions_args = None
        self.init_tool_sessions_kwargs = None

    async def init_tool_sessions(self, tool_server, exit_stack, request_id,
                                 mcp_tools):
        """Record the call and its arguments; no session is started."""
        received = (tool_server, exit_stack, request_id, mcp_tools)
        self.init_tool_sessions_called = True
        self.init_tool_sessions_args = received

    async def cleanup_session(self) -> None:
        """Stub: there is nothing to clean up."""
        return None

    async def call_tool(self):
        """Stub: return an empty list."""
        return list()

    def render_for_completion(self):
        """Stub: return an empty list."""
        return list()

    def need_builtin_tool_call(self) -> bool:
        """Stub: always report that no built-in tool call is needed."""
        return False

    def append_output(self, output) -> None:
        """Stub: discard the output."""
        return None
44+
45+
46+
@pytest.fixture
def mock_serving_responses():
    """Return a spec'd OpenAIServingResponses mock with a mock tool server."""
    mocked_serving = MagicMock(spec=OpenAIServingResponses)
    mocked_tool_server = MagicMock(spec=ToolServer)
    mocked_serving.tool_server = mocked_tool_server
    return mocked_serving
52+
53+
54+
@pytest.fixture
def mock_context():
    """Return a new MockConversationContext instance."""
    context = MockConversationContext()
    return context
58+
59+
60+
@pytest.fixture
def mock_exit_stack():
    """Return a MagicMock constrained to the AsyncExitStack interface."""
    exit_stack = MagicMock(spec=AsyncExitStack)
    return exit_stack
64+
65+
66+
class TestInitializeToolSessions:
    """Tests for OpenAIServingResponses._initialize_tool_sessions."""

    @pytest_asyncio.fixture
    async def serving_responses_instance(self):
        """Create a real OpenAIServingResponses backed by mocked dependencies."""
        # Minimal mocks for the constructor's required dependencies.
        engine_client = MagicMock()
        engine_client.get_model_config = AsyncMock()

        model_config = MagicMock()
        model_config.hf_config.model_type = "test"
        model_config.get_diff_sampling_param.return_value = {}

        models = MagicMock()

        tool_server = MagicMock(spec=ToolServer)

        # The object under test is real; only its collaborators are mocked.
        instance = OpenAIServingResponses(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
            tool_server=tool_server,
        )

        return instance

    @pytest.mark.asyncio
    async def test_initialize_tool_sessions_skipped_without_tools(
            self, serving_responses_instance, mock_context, mock_exit_stack):
        """A request with an empty tool list must not initialize sessions."""
        request = ResponsesRequest(input="test input", tools=[])

        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)

        assert mock_context.init_tool_sessions_called is False
        assert mock_context.init_tool_sessions_args is None

    @pytest.mark.asyncio
    async def test_initialize_tool_sessions(self, serving_responses_instance,
                                            mock_context, mock_exit_stack):
        """A request with tools initializes sessions with the right arguments.

        The tools used here are built-in ones (no entry has type "mcp"), so
        the MCP tool mapping forwarded to the context is expected to be empty.
        """
        # Built-in tools only; none of them is an MCP tool.
        tools = [
            {
                "type": "web_search_preview"
            },
            {
                "type": "code_interpreter",
                "container": {
                    "type": "auto"
                }
            },
        ]

        request = ResponsesRequest(input="test input", tools=tools)

        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)

        # The context must have been asked to initialize tool sessions ...
        assert mock_context.init_tool_sessions_called

        # ... and must have received exactly what the serving layer forwards.
        (tool_server, exit_stack, request_id,
         mcp_tools) = mock_context.init_tool_sessions_args
        assert tool_server is serving_responses_instance.tool_server
        assert exit_stack is mock_exit_stack
        assert request_id == request.request_id
        assert mcp_tools == {}

vllm/entrypoints/openai/serving_responses.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,19 @@ def _make_request_with_harmony(
445445

446446
return messages, [prompt_token_ids], [engine_prompt]
447447

448+
async def _initialize_tool_sessions(self, request: ResponsesRequest,
                                    context: ConversationContext,
                                    exit_stack: AsyncExitStack):
    """Initialize tool sessions on ``context`` if the request carries tools.

    Does nothing when ``request.tools`` is empty (or ``None``). Otherwise
    builds a mapping of MCP tools keyed by ``server_label`` and awaits
    ``context.init_tool_sessions`` with this server's ``tool_server``.

    Args:
        request: The incoming request; ``tools`` and ``request_id`` are read.
        context: Conversation context whose ``init_tool_sessions`` is awaited.
        exit_stack: Forwarded unchanged to ``context.init_tool_sessions``.
    """
    # Only initialize the tool session if the request actually needs tools.
    if not request.tools:
        return
    # Only tools of type "mcp" are forwarded by label; others are left out.
    mcp_tools = {
        tool.server_label: tool
        for tool in request.tools if tool.type == "mcp"
    }
    await context.init_tool_sessions(self.tool_server, exit_stack,
                                     request.request_id, mcp_tools)
460+
448461
async def responses_full_generator(
449462
self,
450463
request: ResponsesRequest,
@@ -461,12 +474,8 @@ async def responses_full_generator(
461474

462475
async with AsyncExitStack() as exit_stack:
463476
try:
464-
mcp_tools = {
465-
tool.server_label: tool
466-
for tool in request.tools if tool.type == "mcp"
467-
}
468-
await context.init_tool_sessions(self.tool_server, exit_stack,
469-
request.request_id, mcp_tools)
477+
await self._initialize_tool_sessions(request, context,
478+
exit_stack)
470479
async for _ in result_generator:
471480
pass
472481
except asyncio.CancelledError:
@@ -1650,12 +1659,10 @@ def _increment_sequence_number_and_return(
16501659
async with AsyncExitStack() as exit_stack:
16511660
processer = None
16521661
if self.use_harmony:
1653-
mcp_tools = {
1654-
tool.server_label: tool
1655-
for tool in request.tools if tool.type == "mcp"
1656-
}
1657-
await context.init_tool_sessions(self.tool_server, exit_stack,
1658-
request.request_id, mcp_tools)
1662+
# TODO: in streaming, we noticed this bug:
1663+
# https://github.com/vllm-project/vllm/issues/25697
1664+
await self._initialize_tool_sessions(request, context,
1665+
exit_stack)
16591666
processer = self._process_harmony_streaming_events
16601667
else:
16611668
processer = self._process_simple_streaming_events

0 commit comments

Comments
 (0)