1919from huggingface_inference_toolkit .handler import get_inference_handler_either_custom_or_default_handler
2020from huggingface_inference_toolkit .serialization .base import ContentType
2121from huggingface_inference_toolkit .serialization .json_utils import Jsoner
22- from huggingface_inference_toolkit .utils import _load_repository_from_hf
22+ from huggingface_inference_toolkit .utils import _load_repository_from_hf , convert_params_to_int_or_bool
2323
2424
2525def config_logging (level = logging .INFO ):
@@ -64,8 +64,6 @@ async def health(request):
6464
6565async def predict (request ):
6666 try :
67- # tracks request time
68- start_time = perf_counter ()
6967 # extracts content from request
7068 content_type = request .headers .get ("content-Type" , None )
7169 # try to deserialize payload
@@ -74,13 +72,16 @@ async def predict(request):
7472 if "inputs" not in deserialized_body :
7573 raise ValueError (f"Body needs to provide a inputs key, recieved: { orjson .dumps (deserialized_body )} " )
7674
75+ # check for query parameter and add them to the body
76+ if request .query_params and "parameters" not in deserialized_body :
77+ deserialized_body ["parameters" ] = convert_params_to_int_or_bool (dict (request .query_params ))
78+ print (deserialized_body )
79+
80+ # tracks request time
81+ start_time = perf_counter ()
7782 # run async not blocking call
7883 pred = await async_handler_call (inference_handler , deserialized_body )
79- # run sync blocking call -> slighty faster for < 200ms prediction time
80- # pred = inference_handler(deserialized_body)
81-
8284 # log request time
83- # TODO: repalce with middleware
8485 logger .info (f"POST { request .url .path } | Duration: { (perf_counter ()- start_time ) * 1000 :.2f} ms" )
8586
8687 # response extracts content from request
0 commit comments