
Commit 3a77514

[responsesAPI] support input output messages for non harmony models (vllm-project#29549)

Authored by qandrew (Andrew Xia)
Signed-off-by: Andrew Xia <[email protected]>
Co-authored-by: Andrew Xia <[email protected]>

1 parent bbfb55c · commit 3a77514

File tree

4 files changed: +64 −11 lines changed

- tests/entrypoints/openai/test_response_api_simple.py
- vllm/entrypoints/context.py
- vllm/entrypoints/openai/protocol.py
- vllm/entrypoints/openai/serving_responses.py

tests/entrypoints/openai/test_response_api_simple.py
Lines changed: 18 additions & 0 deletions

@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
     assert response.status == "completed"
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_enable_response_messages(client: OpenAI, model_name: str):
+    response = await client.responses.create(
+        model=model_name,
+        input="Hello?",
+        extra_body={"enable_response_messages": True},
+    )
+    assert response.status == "completed"
+    assert response.input_messages[0]["type"] == "raw_message_tokens"
+    assert type(response.input_messages[0]["message"]) is str
+    assert len(response.input_messages[0]["message"]) > 10
+    assert type(response.input_messages[0]["tokens"][0]) is int
+    assert type(response.output_messages[0]["message"]) is str
+    assert len(response.output_messages[0]["message"]) > 10
+    assert type(response.output_messages[0]["tokens"][0]) is int
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_reasoning_item(client: OpenAI, model_name: str):
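
The new test exercises the feature exactly the way a client would. For reference, the same call against a standalone server looks like the sketch below, assuming a vLLM OpenAI-compatible server at http://localhost:8000/v1 serving a model named "my-model" (both placeholders). The flag travels in extra_body because the upstream openai SDK does not know about this vLLM extension.

# Minimal sketch; server URL and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.responses.create(
    model="my-model",
    input="Hello?",
    extra_body={"enable_response_messages": True},
)
# Each returned entry pairs the raw rendered text with its token ids.
print(response.input_messages[0]["message"])     # fully rendered prompt string
print(response.input_messages[0]["tokens"][:8])  # corresponding token ids
print(response.output_messages[0]["message"])    # raw completion text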

vllm/entrypoints/context.py
Lines changed: 22 additions & 0 deletions

@@ -23,6 +23,7 @@
 )
 from vllm.entrypoints.openai.protocol import (
     ResponseInputOutputItem,
+    ResponseRawMessageAndToken,
     ResponsesRequest,
 )
 from vllm.entrypoints.responses_utils import construct_tool_dicts
@@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
 
 
 class SimpleContext(ConversationContext):
+    """This is a context that cannot handle MCP tool calls"""
+
     def __init__(self):
         self.last_output = None
         self.num_prompt_tokens = 0
@@ -158,6 +161,9 @@ def __init__(self):
         # not implemented yet for SimpleContext
         self.all_turn_metrics = []
 
+        self.input_messages: list[ResponseRawMessageAndToken] = []
+        self.output_messages: list[ResponseRawMessageAndToken] = []
+
     def append_output(self, output) -> None:
         self.last_output = output
         if not isinstance(output, RequestOutput):
@@ -166,6 +172,22 @@ def append_output(self, output) -> None:
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+        if len(self.input_messages) == 0:
+            output_prompt = output.prompt or ""
+            output_prompt_token_ids = output.prompt_token_ids or []
+            self.input_messages.append(
+                ResponseRawMessageAndToken(
+                    message=output_prompt,
+                    tokens=output_prompt_token_ids,
+                )
+            )
+        self.output_messages.append(
+            ResponseRawMessageAndToken(
+                message=output.outputs[0].text,
+                tokens=output.outputs[0].token_ids,
+            )
+        )
+
     def append_tool_output(self, output) -> None:
         raise NotImplementedError("Should not be called.")

vllm/entrypoints/openai/protocol.py
Lines changed: 18 additions & 4 deletions

@@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
     return [serialize_message(msg) for msg in msgs] if msgs else None
 
 
+class ResponseRawMessageAndToken(OpenAIBaseModel):
+    """Class to show the raw message.
+    If message / tokens diverge, tokens is the source of truth"""
+
+    message: str
+    tokens: list[int]
+    type: Literal["raw_message_tokens"] = "raw_message_tokens"
+
+
+ResponseInputOutputMessage: TypeAlias = (
+    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
+)
+
+
 class ResponsesResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
     created_at: int = Field(default_factory=lambda: int(time.time()))
@@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
     # These are populated when enable_response_messages is set to True
     # NOTE: custom serialization is needed
     # see serialize_input_messages and serialize_output_messages
-    input_messages: list[ChatCompletionMessageParam] | None = None
-    output_messages: list[ChatCompletionMessageParam] | None = None
+    input_messages: ResponseInputOutputMessage | None = None
+    output_messages: ResponseInputOutputMessage | None = None
     # --8<-- [end:responses-extra-params]
 
     # NOTE: openAI harmony doesn't serialize TextContent properly,
@@ -1658,8 +1672,8 @@ def from_request(
         output: list[ResponseOutputItem],
         status: ResponseStatus,
         usage: ResponseUsage | None = None,
-        input_messages: list[ChatCompletionMessageParam] | None = None,
-        output_messages: list[ChatCompletionMessageParam] | None = None,
+        input_messages: ResponseInputOutputMessage | None = None,
+        output_messages: ResponseInputOutputMessage | None = None,
     ) -> "ResponsesResponse":
        incomplete_details: IncompleteDetails | None = None
        if status == "incomplete":
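
Since OpenAIBaseModel is, in vLLM, a pydantic model, the new type serializes to a small tagged dict; the literal "raw_message_tokens" acts as a discriminator so clients can tell raw token-level messages (the non-harmony path) apart from ChatCompletionMessageParam dicts (the harmony path) inside the ResponseInputOutputMessage union. A quick illustration with a plain pydantic stand-in, field values invented:

# Stand-in for OpenAIBaseModel, which subclasses pydantic's BaseModel.
from typing import Literal
from pydantic import BaseModel

class ResponseRawMessageAndToken(BaseModel):
    """If message / tokens diverge, tokens is the source of truth."""
    message: str
    tokens: list[int]
    type: Literal["raw_message_tokens"] = "raw_message_tokens"

raw = ResponseRawMessageAndToken(message="<|user|>Hello?", tokens=[27, 91, 882])
print(raw.model_dump())
# {'message': '<|user|>Hello?', 'tokens': [27, 91, 882], 'type': 'raw_message_tokens'}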

vllm/entrypoints/openai/serving_responses.py
Lines changed: 6 additions & 7 deletions

@@ -86,6 +86,7 @@
     ResponseCompletedEvent,
     ResponseCreatedEvent,
     ResponseInProgressEvent,
+    ResponseInputOutputMessage,
     ResponseReasoningPartAddedEvent,
     ResponseReasoningPartDoneEvent,
     ResponsesRequest,
@@ -629,8 +630,8 @@ async def responses_full_generator(
         # "completed" is implemented as the "catch-all" for now.
         status: ResponseStatus = "completed"
 
-        input_messages = None
-        output_messages = None
+        input_messages: ResponseInputOutputMessage | None = None
+        output_messages: ResponseInputOutputMessage | None = None
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
             output = self._make_response_output_items_with_harmony(context)
@@ -670,12 +671,10 @@ async def responses_full_generator(
 
         output = self._make_response_output_items(request, final_output, tokenizer)
 
-        # TODO: context for non-gptoss models doesn't use messages
-        # so we can't get them out yet
         if request.enable_response_messages:
-            raise NotImplementedError(
-                "enable_response_messages is currently only supported for gpt-oss"
-            )
+            input_messages = context.input_messages
+            output_messages = context.output_messages
+
         # Calculate usage.
         assert final_res.prompt_token_ids is not None
         num_tool_output_tokens = 0
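
Net effect: for non-harmony models, enable_response_messages no longer raises NotImplementedError; the serving layer now reads the messages that SimpleContext accumulated. Based on the new test's assertions, the extra response fields take roughly this shape (values invented for illustration):

# Approximate payload shape of the new fields for a non-harmony model.
response_extra_fields = {
    "input_messages": [
        {
            "type": "raw_message_tokens",
            "message": "<chat-template-rendered prompt>",  # raw prompt text
            "tokens": [101, 2, 42],                        # its token ids
        }
    ],
    "output_messages": [
        {
            "type": "raw_message_tokens",
            "message": "Hello! How can I help you today?",
            "tokens": [9906, 0, 2650],
        }
    ],
}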

0 commit comments

Comments
 (0)