@@ -54,6 +54,7 @@ async def prepare_model_artifacts():
5454
5555
5656def _eager_model_dl ():
57+ logger .debug ("Model download" )
5758 global MODEL_DOWNLOADED
5859 from huggingface_inference_toolkit .heavy_utils import load_repository_from_hf
5960 # 1. check if model artifacts available in HF_MODEL_DIR
@@ -81,6 +82,8 @@ def _eager_model_dl():
8182 Provided values are:
8283 HF_MODEL_DIR: { HF_MODEL_DIR } and HF_MODEL_ID:{ HF_MODEL_ID } """
8384 )
85+ else :
86+ logger .debug ("Model already downloaded in %s" , HF_MODEL_DIR )
8487 MODEL_DOWNLOADED = True
8588
8689
@@ -101,95 +104,99 @@ async def metrics(request):
101104
102105
async def predict(request):
    """Handle a prediction request.

    Deserializes the request body according to its content type, lazily
    resolves (and caches) the inference handler for the requested task,
    runs inference, and serializes the prediction into the response format
    the caller accepts. Any failure is logged and returned as a JSON error
    with status 400.
    """
    with idle.request_witnesses():
        logger.debug("Received request, scope %s", request.scope)

        global INFERENCE_HANDLERS

        if not MODEL_DOWNLOADED:
            # Run BOTH the lock acquisition and the download in a worker
            # thread: MODEL_DL_LOCK is a synchronous lock, and acquiring it
            # directly in the coroutine would block the event loop for every
            # concurrent request while another request downloads the model.
            def _locked_eager_model_dl():
                with MODEL_DL_LOCK:
                    _eager_model_dl()

            await asyncio.to_thread(_locked_eager_model_dl)
        try:
            task = request.path_params.get("task", HF_TASK)
            # extracts content from request
            content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower()
            # try to deserialize payload
            deserialized_body = ContentType.get_deserializer(content_type, task).deserialize(
                await request.body()
            )
            # checks if input schema is correct
            if "inputs" not in deserialized_body and "instances" not in deserialized_body:
                raise ValueError(
                    f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}"
                )

            # Decode base64 audio inputs before running inference
            if "parameters" in deserialized_body and HF_TASK in {
                "automatic-speech-recognition",
                "audio-classification",
            }:
                # Be more strict on base64 decoding, the provided string should be valid base64 encoded data
                deserialized_body["inputs"] = base64.b64decode(
                    deserialized_body["inputs"], validate=True
                )

            # check for query parameter and add them to the body
            if request.query_params and "parameters" not in deserialized_body:
                deserialized_body["parameters"] = convert_params_to_int_or_bool(
                    dict(request.query_params)
                )

            # We lazily load pipelines for alt tasks

            if task == "feature-extraction" and HF_TASK in [
                "sentence-similarity",
                "sentence-embeddings",
                "sentence-ranking",
            ]:
                task = "sentence-embeddings"
            inference_handler = INFERENCE_HANDLERS.get(task)
            if not inference_handler:
                # Double-checked locking: re-test under the lock so only one
                # request builds the handler for a given task.
                with INFERENCE_HANDLERS_LOCK:
                    if task not in INFERENCE_HANDLERS:
                        inference_handler = get_inference_handler_either_custom_or_default_handler(
                            HF_MODEL_DIR, task=task)
                        INFERENCE_HANDLERS[task] = inference_handler
                    else:
                        inference_handler = INFERENCE_HANDLERS[task]
            # tracks request time
            start_time = perf_counter()

            if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler):
                # Give the handler access to the request so it can detect
                # client disconnects mid-inference.
                deserialized_body['handler_params'] = {
                    'request': request
                }

            logger.debug("Calling inference handler prediction routine")
            # run async not blocking call
            pred = await async_handler_call(inference_handler, deserialized_body)

            # log request time
            logger.info(
                f"POST {request.url.path} | Duration: {(perf_counter()-start_time) * 1000:.2f} ms"
            )

            if should_discard_left() and pred is None:
                logger.info("No content returned as caller already left")
                return Response(status_code=204)

            # response extracts content from request
            accept = request.headers.get("accept")
            if accept is None or accept == "*/*":
                accept = os.environ.get("DEFAULT_ACCEPT", "application/json")
            logger.info("Request accepts %s", accept)
            # serializes the prediction and responds with json
            serialized_response_body = ContentType.get_serializer(accept).serialize(
                pred, accept
            )
            return Response(serialized_response_body, media_type=accept)
        except Exception as e:
            logger.exception(e)
            return Response(
                Jsoner.serialize({"error": str(e)}),
                status_code=400,
                media_type="application/json",
            )
193200
194201
195202# Create app based on which cloud environment is used
0 commit comments