133 changes: 119 additions & 14 deletions src/transformers/commands/serving.py
@@ -724,6 +724,11 @@ def chat_completion(request: Request, body: dict):
@app.post("/v1/responses")
def responses(request: dict):
self.validate_response_request(request=request)
# Support non-streaming mode when `stream=false` is provided
stream = request.get("stream", True)
if not stream:
response_obj = self.generate_response_non_streaming(request)
return JSONResponse(response_obj)

output = self.generate_response(request)
return StreamingResponse(output, media_type="text/event-stream")
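A minimal client-side sketch of the new branch (not part of the diff; the port and model name are assumptions — adjust to wherever `transformers serve` is running):

```python
import requests

BASE_URL = "http://localhost:8000/v1"  # assumed address of a local `transformers serve`

payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",  # the model used in the tests below
    "input": "Hello!",
    "stream": False,  # hits the new non-streaming path: one JSON payload, no SSE
    "max_output_tokens": 32,
}
resp = requests.post(f"{BASE_URL}/responses", json=payload, timeout=120)
resp.raise_for_status()

body = resp.json()
print(body["status"])                           # "completed"
print(body["output"][0]["content"][0]["text"])  # the generated text
```

With `"stream": True` the same endpoint keeps returning a `text/event-stream` of Response events, as before.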
@@ -1334,19 +1339,31 @@ def generate_with_cache(**kwargs):
results = "" # reset the results -> results will now track the final response
continue
else:
continue

response_output_text_delta = ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=f"msg_{request_id}",
sequence_number=sequence_number,
output_index=output_index,
content_index=content_index,
delta=result,
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
)
sequence_number += 1
yield self.build_response_event(response_output_text_delta)
response_output_text_delta = ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=f"msg_{request_id}",
sequence_number=sequence_number,
output_index=output_index,
content_index=content_index,
delta=result,
logprobs=[],
)
sequence_number += 1
yield self.build_response_event(response_output_text_delta)
else:
# Normal path: emit token deltas when not filtering CoT
if result:
response_output_text_delta = ResponseTextDeltaEvent(
type="response.output_text.delta",
item_id=f"msg_{request_id}",
sequence_number=sequence_number,
output_index=output_index,
content_index=content_index,
delta=result,
logprobs=[],
)
sequence_number += 1
yield self.build_response_event(response_output_text_delta)

# Signal the end of the text generation
response_output_text_done = ResponseTextDoneEvent(
@@ -1356,7 +1373,7 @@ def generate_with_cache(**kwargs):
output_index=output_index,
content_index=0,
text=results,
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
logprobs=[],
)
sequence_number += 1
yield self.build_response_event(response_output_text_done)
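For reference, a sketch of the consumer side of this stream using the OpenAI SDK (as the tests do). The server address and model are assumptions, and it presumes the SDK's SSE parsing lines up with the events emitted here:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="<KEY>")

stream = client.responses.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    input="Tell me what you can do.",
    stream=True,
    max_output_tokens=30,
)

pieces = []
full_text = ""
for event in stream:
    if event.type == "response.output_text.delta":
        pieces.append(event.delta)  # incremental chunks from ResponseTextDeltaEvent
    elif event.type == "response.output_text.done":
        full_text = event.text      # complete text carried by ResponseTextDoneEvent

print(full_text or "".join(pieces))
```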
@@ -1455,6 +1472,94 @@ def generate_with_cache(**kwargs):

return stream_response(generation_streamer, request_id)

def generate_response_non_streaming(self, req: dict) -> dict:
"""
Generates an OpenAI Response in non-streaming mode (single JSON payload).

Args:
req (`dict`): The request to generate an OpenAI Response for.

Returns:
`dict`: The OpenAI `Response` serialized as a dict.
"""
model_id_and_revision = self.process_model_name(req["model"])
must_discard_cache = model_id_and_revision != self.last_model
self.last_model = model_id_and_revision
model, processor = self.load_model_and_processor(model_id_and_revision)

if isinstance(req["input"], str):
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
inputs.append({"role": "user", "content": req["input"]})
elif isinstance(req["input"], list):
if "instructions" in req:
if req["input"][0]["role"] != "system":
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
else:
inputs = req["input"]
inputs[0]["content"] = req["instructions"]
else:
inputs = req["input"]
elif isinstance(req["input"], dict):
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
inputs.append(req["input"])
else:
raise ValueError("inputs should be a list, dict, or str")

inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")
inputs = inputs.to(model.device)
request_id = req.get("previous_response_id", "req_0")

# Temporary hack for GPTOSS 1: don't filter special tokens
skip_special_tokens = True
if "gptoss" in model.config.architectures[0].lower():
skip_special_tokens = False

generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)

last_kv_cache = None
if self.is_continuation(req) and not must_discard_cache:
seq_len = self.last_kv_cache.get_seq_length()
if inputs.shape[-1] > seq_len:
last_kv_cache = self.last_kv_cache

generate_output = model.generate(
inputs=inputs,
attention_mask=torch.ones_like(inputs),
generation_config=generation_config,
return_dict_in_generate=True,
past_key_values=last_kv_cache,
)
# save KV cache
self.last_kv_cache = generate_output.past_key_values

# Decode full text
full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0]

created_at = time.time()
response_output_item = ResponseOutputMessage(
id=f"msg_{request_id}",
type="message",
status="completed",
role="assistant",
content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])],
annotations=[],
)
response_completed = Response(
id=f"resp_{request_id}",
created_at=created_at,
status="completed",
model=model_id_and_revision,
instructions=req.get("instructions"),
text={"format": {"type": "text"}},
output=[response_output_item],
object="response",
tools=[],
parallel_tool_calls=req.get("parallel_tool_calls", False),
tool_choice="auto",
metadata=req.get("metadata"),
)
return response_completed.model_dump(exclude_none=True)

def generate_transcription(self, req: dict) -> Generator[str, None, None]:
"""
Generates an OpenAI Transcription using the audio file.
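A quick reference for the `input` normalization inside `generate_response_non_streaming` above (illustrative values only; the resulting message lists follow the same branching as the code):

```python
# str input: optional `instructions` becomes a system turn, the string becomes the user turn
req_str = {"input": "Hello!", "instructions": "Be brief."}
# -> [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hello!"}]

# dict input: appended after the optional system turn
req_dict = {"input": {"role": "user", "content": "Hello!"}}
# -> [{"role": "user", "content": "Hello!"}]

# list input with no leading system turn: `instructions` is prepended as one
req_list = {"input": [{"role": "user", "content": "Hello!"}], "instructions": "Be brief."}
# -> [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hello!"}]

# list input that already starts with a system turn: its content is overwritten by `instructions`
req_list_sys = {
    "input": [{"role": "system", "content": "ignored"}, {"role": "user", "content": "Hello!"}],
    "instructions": "Be brief.",
}
# -> [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hello!"}]
```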
67 changes: 48 additions & 19 deletions tests/commands/test_serving.py
@@ -670,22 +670,23 @@ def test_request(self):
}
all_payloads = asyncio.run(self.run_server(request))

order_of_payloads = [
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseContentPartAddedEvent,
ResponseTextDeltaEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemDoneEvent,
ResponseCompletedEvent,
]
# Allow variable number of delta events depending on tokenizer/streamer behavior
self.assertGreaterEqual(len(all_payloads), 8)

# Start markers
self.assertIsInstance(all_payloads[0], ResponseCreatedEvent)
self.assertIsInstance(all_payloads[1], ResponseInProgressEvent)
self.assertIsInstance(all_payloads[2], ResponseOutputItemAddedEvent)
self.assertIsInstance(all_payloads[3], ResponseContentPartAddedEvent)

# At least one delta event during streaming
self.assertTrue(any(isinstance(p, ResponseTextDeltaEvent) for p in all_payloads[4:-4]))

self.assertEqual(len(all_payloads), 10)
for payload, payload_type in zip(all_payloads, order_of_payloads):
self.assertIsInstance(payload, payload_type)
# Closing markers
self.assertIsInstance(all_payloads[-4], ResponseTextDoneEvent)
self.assertIsInstance(all_payloads[-3], ResponseContentPartDoneEvent)
self.assertIsInstance(all_payloads[-2], ResponseOutputItemDoneEvent)
self.assertIsInstance(all_payloads[-1], ResponseCompletedEvent)

# TODO: one test for each request flag, to confirm it is working as expected
# TODO: speed-based test to confirm that KV cache is working across requests
@@ -716,6 +717,8 @@ def test_full_request(self):
"input": "Tell me what you can do.",
"stream": True,
"max_output_tokens": 30,
# Disable sampling for deterministic output
"temperature": 0,
}
all_payloads = asyncio.run(self.run_server(request))

@@ -725,12 +728,38 @@ def test_full_request(self):
full_text += token.delta

# Verify that the system prompt went through.
self.assertTrue(
full_text.startswith(
"As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
)
# With deterministic decoding, exact wording can still vary across versions.
# Assert non-empty output and that it references sports.
self.assertTrue(len(full_text) > 0)
self.assertIn("sports", full_text.lower())

@slow
def test_non_streaming_request(self):
"""Tests that an inference using the Responses API with stream=False returns a single Response payload."""
from openai import OpenAI
from openai.types.responses import Response as OpenAIResponse

client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
resp = client.responses.create(
model="Qwen/Qwen2.5-0.5B-Instruct",
instructions="You are a helpful assistant.",
input="Hello!",
stream=False,
max_output_tokens=5,
)

# Should be a single Response object with completed status and one output item containing text
self.assertIsInstance(resp, OpenAIResponse)
self.assertEqual(resp.status, "completed")
self.assertTrue(len(resp.output) >= 1)
first_item = resp.output[0]
self.assertEqual(first_item.type, "message")
self.assertEqual(first_item.status, "completed")
self.assertTrue(len(first_item.content) >= 1)
first_part = first_item.content[0]
self.assertEqual(first_part.type, "output_text")
self.assertIsInstance(first_part.text, str)
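As a usage note for the object this test receives (the attribute path mirrors the `Response` built on the server side; `output_text` is an `openai` SDK convenience whose availability depends on the installed version):

```python
# Pulling the generated text out of the non-streaming response:
text = resp.output[0].content[0].text
# Recent openai SDK versions also expose an aggregate convenience:
# text = resp.output_text
```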


class ServeInfrastructureTest(unittest.TestCase):
@classmethod