Skip to content

Commit d2b880d

Browse files
vaibhavjainwiz authored and dtrifiro committed
[Bugfix] Missing Content Type returns 500 Internal Server Error (vllm-project#13193)
1 parent 4b77141 commit d2b880d

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

tests/entrypoints/openai/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
156156
max_tokens=10)
157157

158158
assert len(response.choices) == 1
159+
160+
161+
@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
    """A request whose Content-Type is not JSON must come back as an API
    error (415 Unsupported Media Type), not a 500 Internal Server Error."""
    messages = [{"role": "user", "content": "Write a long story"}]
    async_client = server.get_async_client()

    # The server validates the Content-Type header before parsing the
    # body, so this should fail fast with a non-2xx status error.
    with pytest.raises(openai.APIStatusError):
        await async_client.chat.completions.create(
            messages=messages,
            model=MODEL_NAME,
            max_tokens=10000,
            extra_headers={
                "Content-Type": "application/x-www-form-urlencoded"
            })

vllm/entrypoints/openai/api_server.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
2121

2222
import uvloop
23-
from fastapi import APIRouter, FastAPI, HTTPException, Request
23+
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
2424
from fastapi.exceptions import RequestValidationError
2525
from fastapi.middleware.cors import CORSMiddleware
2626
from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -253,6 +253,15 @@ def _cleanup_ipc_path():
253253
multiprocess.mark_process_dead(engine_process.pid)
254254

255255

256+
async def validate_json_request(raw_request: Request):
    """FastAPI dependency that rejects non-JSON requests up front.

    Without this check, a request with the wrong Content-Type reaches the
    body parser and surfaces as a 500 Internal Server Error; here it is
    turned into a proper 415 Unsupported Media Type response.

    Raises:
        HTTPException: with status 415 when the request's media type is
            not ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    # Compare only the base media type: a strict equality test would
    # wrongly reject valid headers that carry parameters, e.g.
    # "application/json; charset=utf-8" (RFC 9110 allows parameters).
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )
264+
256265
router = APIRouter()
257266

258267

@@ -336,7 +345,7 @@ async def ping(raw_request: Request) -> Response:
336345
return await health(raw_request)
337346

338347

339-
@router.post("/tokenize")
348+
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
340349
@with_cancellation
341350
async def tokenize(request: TokenizeRequest, raw_request: Request):
342351
handler = tokenization(raw_request)
@@ -351,7 +360,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
351360
assert_never(generator)
352361

353362

354-
@router.post("/detokenize")
363+
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
355364
@with_cancellation
356365
async def detokenize(request: DetokenizeRequest, raw_request: Request):
357366
handler = tokenization(raw_request)
@@ -380,7 +389,8 @@ async def show_version():
380389
return JSONResponse(content=ver)
381390

382391

383-
@router.post("/v1/chat/completions")
392+
@router.post("/v1/chat/completions",
393+
dependencies=[Depends(validate_json_request)])
384394
@with_cancellation
385395
async def create_chat_completion(request: ChatCompletionRequest,
386396
raw_request: Request):
@@ -401,7 +411,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
401411
return StreamingResponse(content=generator, media_type="text/event-stream")
402412

403413

404-
@router.post("/v1/completions")
414+
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
405415
@with_cancellation
406416
async def create_completion(request: CompletionRequest, raw_request: Request):
407417
handler = completion(raw_request)
@@ -419,7 +429,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
419429
return StreamingResponse(content=generator, media_type="text/event-stream")
420430

421431

422-
@router.post("/v1/embeddings")
432+
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
423433
@with_cancellation
424434
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
425435
handler = embedding(raw_request)
@@ -465,7 +475,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
465475
assert_never(generator)
466476

467477

468-
@router.post("/pooling")
478+
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
469479
@with_cancellation
470480
async def create_pooling(request: PoolingRequest, raw_request: Request):
471481
handler = pooling(raw_request)
@@ -483,7 +493,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
483493
assert_never(generator)
484494

485495

486-
@router.post("/score")
496+
@router.post("/score", dependencies=[Depends(validate_json_request)])
487497
@with_cancellation
488498
async def create_score(request: ScoreRequest, raw_request: Request):
489499
handler = score(raw_request)
@@ -501,7 +511,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
501511
assert_never(generator)
502512

503513

504-
@router.post("/v1/score")
514+
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
505515
@with_cancellation
506516
async def create_score_v1(request: ScoreRequest, raw_request: Request):
507517
logger.warning(
@@ -511,7 +521,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
511521
return await create_score(request, raw_request)
512522

513523

514-
@router.post("/rerank")
524+
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
515525
@with_cancellation
516526
async def do_rerank(request: RerankRequest, raw_request: Request):
517527
handler = rerank(raw_request)
@@ -528,7 +538,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
528538
assert_never(generator)
529539

530540

531-
@router.post("/v1/rerank")
541+
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
532542
@with_cancellation
533543
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
534544
logger.warning_once(
@@ -539,7 +549,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
539549
return await do_rerank(request, raw_request)
540550

541551

542-
@router.post("/v2/rerank")
552+
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
543553
@with_cancellation
544554
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
545555
return await do_rerank(request, raw_request)
@@ -583,7 +593,7 @@ async def reset_prefix_cache(raw_request: Request):
583593
return Response(status_code=200)
584594

585595

586-
@router.post("/invocations")
596+
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
587597
async def invocations(raw_request: Request):
588598
"""
589599
For SageMaker, routes requests to other handlers based on model `task`.
@@ -633,7 +643,8 @@ async def stop_profile(raw_request: Request):
633643
"Lora dynamic loading & unloading is enabled in the API server. "
634644
"This should ONLY be used for local development!")
635645

636-
@router.post("/v1/load_lora_adapter")
646+
@router.post("/v1/load_lora_adapter",
647+
dependencies=[Depends(validate_json_request)])
637648
async def load_lora_adapter(request: LoadLoraAdapterRequest,
638649
raw_request: Request):
639650
handler = models(raw_request)
@@ -644,7 +655,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest,
644655

645656
return Response(status_code=200, content=response)
646657

647-
@router.post("/v1/unload_lora_adapter")
658+
@router.post("/v1/unload_lora_adapter",
659+
dependencies=[Depends(validate_json_request)])
648660
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
649661
raw_request: Request):
650662
handler = models(raw_request)

0 commit comments

Comments (0)