20
20
from typing import AsyncIterator , Dict , Optional , Set , Tuple , Union
21
21
22
22
import uvloop
23
- from fastapi import APIRouter , FastAPI , HTTPException , Request
23
+ from fastapi import APIRouter , Depends , FastAPI , HTTPException , Request
24
24
from fastapi .exceptions import RequestValidationError
25
25
from fastapi .middleware .cors import CORSMiddleware
26
26
from fastapi .responses import JSONResponse , Response , StreamingResponse
@@ -253,6 +253,15 @@ def _cleanup_ipc_path():
253
253
multiprocess .mark_process_dead (engine_process .pid )
254
254
255
255
256
+ async def validate_json_request (raw_request : Request ):
257
+ content_type = raw_request .headers .get ("content-type" , "" ).lower ()
258
+ if content_type != "application/json" :
259
+ raise HTTPException (
260
+ status_code = HTTPStatus .UNSUPPORTED_MEDIA_TYPE ,
261
+ detail = "Unsupported Media Type: Only 'application/json' is allowed"
262
+ )
263
+
264
+
256
265
router = APIRouter ()
257
266
258
267
@@ -336,7 +345,7 @@ async def ping(raw_request: Request) -> Response:
336
345
return await health (raw_request )
337
346
338
347
339
- @router .post ("/tokenize" )
348
+ @router .post ("/tokenize" , dependencies = [ Depends ( validate_json_request )] )
340
349
@with_cancellation
341
350
async def tokenize (request : TokenizeRequest , raw_request : Request ):
342
351
handler = tokenization (raw_request )
@@ -351,7 +360,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
351
360
assert_never (generator )
352
361
353
362
354
- @router .post ("/detokenize" )
363
+ @router .post ("/detokenize" , dependencies = [ Depends ( validate_json_request )] )
355
364
@with_cancellation
356
365
async def detokenize (request : DetokenizeRequest , raw_request : Request ):
357
366
handler = tokenization (raw_request )
@@ -380,7 +389,8 @@ async def show_version():
380
389
return JSONResponse (content = ver )
381
390
382
391
383
- @router .post ("/v1/chat/completions" )
392
+ @router .post ("/v1/chat/completions" ,
393
+ dependencies = [Depends (validate_json_request )])
384
394
@with_cancellation
385
395
async def create_chat_completion (request : ChatCompletionRequest ,
386
396
raw_request : Request ):
@@ -401,7 +411,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
401
411
return StreamingResponse (content = generator , media_type = "text/event-stream" )
402
412
403
413
404
- @router .post ("/v1/completions" )
414
+ @router .post ("/v1/completions" , dependencies = [ Depends ( validate_json_request )] )
405
415
@with_cancellation
406
416
async def create_completion (request : CompletionRequest , raw_request : Request ):
407
417
handler = completion (raw_request )
@@ -419,7 +429,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
419
429
return StreamingResponse (content = generator , media_type = "text/event-stream" )
420
430
421
431
422
- @router .post ("/v1/embeddings" )
432
+ @router .post ("/v1/embeddings" , dependencies = [ Depends ( validate_json_request )] )
423
433
@with_cancellation
424
434
async def create_embedding (request : EmbeddingRequest , raw_request : Request ):
425
435
handler = embedding (raw_request )
@@ -465,7 +475,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
465
475
assert_never (generator )
466
476
467
477
468
- @router .post ("/pooling" )
478
+ @router .post ("/pooling" , dependencies = [ Depends ( validate_json_request )] )
469
479
@with_cancellation
470
480
async def create_pooling (request : PoolingRequest , raw_request : Request ):
471
481
handler = pooling (raw_request )
@@ -483,7 +493,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
483
493
assert_never (generator )
484
494
485
495
486
- @router .post ("/score" )
496
+ @router .post ("/score" , dependencies = [ Depends ( validate_json_request )] )
487
497
@with_cancellation
488
498
async def create_score (request : ScoreRequest , raw_request : Request ):
489
499
handler = score (raw_request )
@@ -501,7 +511,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
501
511
assert_never (generator )
502
512
503
513
504
- @router .post ("/v1/score" )
514
+ @router .post ("/v1/score" , dependencies = [ Depends ( validate_json_request )] )
505
515
@with_cancellation
506
516
async def create_score_v1 (request : ScoreRequest , raw_request : Request ):
507
517
logger .warning (
@@ -511,7 +521,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
511
521
return await create_score (request , raw_request )
512
522
513
523
514
- @router .post ("/rerank" )
524
+ @router .post ("/rerank" , dependencies = [ Depends ( validate_json_request )] )
515
525
@with_cancellation
516
526
async def do_rerank (request : RerankRequest , raw_request : Request ):
517
527
handler = rerank (raw_request )
@@ -528,7 +538,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
528
538
assert_never (generator )
529
539
530
540
531
- @router .post ("/v1/rerank" )
541
+ @router .post ("/v1/rerank" , dependencies = [ Depends ( validate_json_request )] )
532
542
@with_cancellation
533
543
async def do_rerank_v1 (request : RerankRequest , raw_request : Request ):
534
544
logger .warning_once (
@@ -539,7 +549,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
539
549
return await do_rerank (request , raw_request )
540
550
541
551
542
- @router .post ("/v2/rerank" )
552
+ @router .post ("/v2/rerank" , dependencies = [ Depends ( validate_json_request )] )
543
553
@with_cancellation
544
554
async def do_rerank_v2 (request : RerankRequest , raw_request : Request ):
545
555
return await do_rerank (request , raw_request )
@@ -583,7 +593,7 @@ async def reset_prefix_cache(raw_request: Request):
583
593
return Response (status_code = 200 )
584
594
585
595
586
- @router .post ("/invocations" )
596
+ @router .post ("/invocations" , dependencies = [ Depends ( validate_json_request )] )
587
597
async def invocations (raw_request : Request ):
588
598
"""
589
599
For SageMaker, routes requests to other handlers based on model `task`.
@@ -633,7 +643,8 @@ async def stop_profile(raw_request: Request):
633
643
"Lora dynamic loading & unloading is enabled in the API server. "
634
644
"This should ONLY be used for local development!" )
635
645
636
- @router .post ("/v1/load_lora_adapter" )
646
+ @router .post ("/v1/load_lora_adapter" ,
647
+ dependencies = [Depends (validate_json_request )])
637
648
async def load_lora_adapter (request : LoadLoraAdapterRequest ,
638
649
raw_request : Request ):
639
650
handler = models (raw_request )
@@ -644,7 +655,8 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest,
644
655
645
656
return Response (status_code = 200 , content = response )
646
657
647
- @router .post ("/v1/unload_lora_adapter" )
658
+ @router .post ("/v1/unload_lora_adapter" ,
659
+ dependencies = [Depends (validate_json_request )])
648
660
async def unload_lora_adapter (request : UnloadLoraAdapterRequest ,
649
661
raw_request : Request ):
650
662
handler = models (raw_request )
0 commit comments