
Commit 54d2596

feat: log level + fixes: async bug, idle bug

* environment log level var
* some long blocking sync calls should be wrapped in a thread (model download)
* idle check should be done for the entire predict call, otherwise in non-idle mode the worker could be kicked in the middle of a request

Signed-off-by: Raphael Glon <[email protected]>

1 parent: 52511c0

File tree: 5 files changed, +101 -87 lines changed

scripts/entrypoint.sh

Lines changed: 1 addition & 1 deletion

@@ -59,4 +59,4 @@ if [[ ! -z "${HF_MODEL_DIR}" ]]; then
 fi

 # Start the server
-exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT}
+exec gunicorn webservice_starlette:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-1} --bind 0.0.0.0:${PORT} --timeout 30
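
Gunicorn's --timeout flag kills and restarts a worker that stays silent for longer than the given number of seconds. With an async worker the heartbeat runs on the event loop, so a long blocking call inside a request handler can starve the heartbeat and get the worker reaped; that interaction is what the thread-offloading fix further down addresses. A minimal sketch of the failure mode (hypothetical endpoints; time.sleep stands in for a blocking model download):

import asyncio
import time

async def blocking_endpoint():
    # Freezes the event loop for 60 s: nothing else runs, including the
    # periodic worker heartbeat, so a 30 s timeout would kill the worker.
    time.sleep(60)
    return "done"

async def offloaded_endpoint():
    # The sync call runs in a worker thread; the loop stays responsive.
    await asyncio.to_thread(time.sleep, 60)
    return "done"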

src/huggingface_inference_toolkit/handler.py

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ def __call__(self, data: Dict[str, Any]):
         :return: prediction output
         """

+        logger.debug("Calling HF default handler")
         # import as late as possible to reduce the footprint
         from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS

@@ -126,6 +127,7 @@ def __call__(self, data: Dict[str, Any]):
         if self.pipeline.task == "token-classification":
             parameters.setdefault("aggregation_strategy", os.environ.get("DEFAULT_AGGREGATION_STRATEGY", "simple"))

+        logger.debug("Performing inference")
         resp = self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else \
             self.pipeline(inputs, **parameters)

src/huggingface_inference_toolkit/idle.py

Lines changed: 5 additions & 1 deletion

@@ -24,7 +24,7 @@ async def live_check_loop():

     while True:
         await asyncio.sleep(sleep_time)
-        LOG.debug("Checking whether we should unload anything from gpu")
+        LOG.debug("Checking whether we should unload anything from memory")

        last_start = LAST_START
        last_end = LAST_END

@@ -50,9 +50,13 @@ async def live_check_loop():
 @contextlib.contextmanager
 def request_witnesses():
     global LAST_START, LAST_END
+    LOG.debug("Last request start was %s", LAST_START)
+    LOG.debug("Last request end was %s", LAST_END)
     # Simple assignment, concurrency safe, no need for any lock
     LAST_START = time.time()
+    LOG.debug("Current request start timestamp %s", LAST_START)
     try:
         yield
     finally:
         LAST_END = time.time()
+        LOG.debug("Current request end timestamp %s", LAST_END)
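
The witness context manager above is what lets the idle reaper distinguish "no traffic" from "request in flight". A standalone sketch of the pattern (simplified; the names and the idle threshold here are illustrative, not the toolkit's actual constants):

import contextlib
import time

LAST_START = 0.0
LAST_END = 0.0

@contextlib.contextmanager
def request_witnesses():
    global LAST_START, LAST_END
    LAST_START = time.time()  # plain float assignment, no lock needed
    try:
        yield
    finally:
        LAST_END = time.time()

def should_unload(idle_seconds: float = 600.0) -> bool:
    # A start newer than the last end means a request is still running,
    # so the worker must never be treated as idle mid-request.
    if LAST_START > LAST_END:
        return False
    return time.time() - LAST_END > idle_seconds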

src/huggingface_inference_toolkit/logging.py

Lines changed: 2 additions & 1 deletion

@@ -1,4 +1,5 @@
 import logging
+import os
 import sys


@@ -9,7 +10,7 @@ def setup_logging():

     # Configure the root logger
     logging.basicConfig(
-        level=logging.INFO,
+        level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO")),
         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
         datefmt="%Y-%m-%d %H:%M:%S",
         stream=sys.stdout,
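
With this change the level is read from the LOG_LEVEL environment variable, e.g. LOG_LEVEL=DEBUG. Note that getattr without a default raises AttributeError for an unrecognized value; a defensive variant (an illustration, not what the commit does) could fall back to INFO:

import logging
import os

# Unknown values such as LOG_LEVEL=verbose fall back to INFO instead of
# raising AttributeError at startup.
level = getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO)
logging.basicConfig(level=level)
logging.getLogger(__name__).debug("debug logging enabled")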

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 91 additions & 84 deletions

@@ -54,6 +54,7 @@ async def prepare_model_artifacts():


 def _eager_model_dl():
+    logger.debug("Model download")
     global MODEL_DOWNLOADED
     from huggingface_inference_toolkit.heavy_utils import load_repository_from_hf
     # 1. check if model artifacts available in HF_MODEL_DIR

@@ -81,6 +82,8 @@ def _eager_model_dl():
         Provided values are:
         HF_MODEL_DIR: {HF_MODEL_DIR} and HF_MODEL_ID:{HF_MODEL_ID}"""
         )
+    else:
+        logger.debug("Model already downloaded in %s", HF_MODEL_DIR)
     MODEL_DOWNLOADED = True


@@ -101,95 +104,99 @@ async def metrics(request):


 async def predict(request):
-    global INFERENCE_HANDLERS
-
-    if not MODEL_DOWNLOADED:
-        with MODEL_DL_LOCK:
-            _eager_model_dl()
-    try:
-        task = request.path_params.get("task", HF_TASK)
-        # extracts content from request
-        content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower()
-        # try to deserialize payload
-        deserialized_body = ContentType.get_deserializer(content_type, task).deserialize(
-            await request.body()
-        )
-        # checks if input schema is correct
-        if "inputs" not in deserialized_body and "instances" not in deserialized_body:
-            raise ValueError(
-                f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}"
-            )
-
-        # Decode base64 audio inputs before running inference
-        if "parameters" in deserialized_body and HF_TASK in {
-            "automatic-speech-recognition",
-            "audio-classification",
-        }:
-            # Be more strict on base64 decoding, the provided string should valid base64 encoded data
-            deserialized_body["inputs"] = base64.b64decode(
-                deserialized_body["inputs"], validate=True
-            )
-
-        # check for query parameter and add them to the body
-        if request.query_params and "parameters" not in deserialized_body:
-            deserialized_body["parameters"] = convert_params_to_int_or_bool(
-                dict(request.query_params)
+    with idle.request_witnesses():
+        logger.debug("Received request, scope %s", request.scope)
+
+        global INFERENCE_HANDLERS
+
+        if not MODEL_DOWNLOADED:
+            with MODEL_DL_LOCK:
+                await asyncio.to_thread(_eager_model_dl)
+        try:
+            task = request.path_params.get("task", HF_TASK)
+            # extracts content from request
+            content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE", "")).lower()
+            # try to deserialize payload
+            deserialized_body = ContentType.get_deserializer(content_type, task).deserialize(
+                await request.body()
             )
-
-        # We lazily load pipelines for alt tasks
-
-        if task == "feature-extraction" and HF_TASK in [
-            "sentence-similarity",
-            "sentence-embeddings",
-            "sentence-ranking",
-        ]:
-            task = "sentence-embeddings"
-        inference_handler = INFERENCE_HANDLERS.get(task)
-        if not inference_handler:
-            with INFERENCE_HANDLERS_LOCK:
-                if task not in INFERENCE_HANDLERS:
-                    inference_handler = get_inference_handler_either_custom_or_default_handler(
-                        HF_MODEL_DIR, task=task)
-                    INFERENCE_HANDLERS[task] = inference_handler
-                else:
-                    inference_handler = INFERENCE_HANDLERS[task]
-        # tracks request time
-        start_time = perf_counter()
-
-        if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler):
-            deserialized_body['handler_params'] = {
-                'request': request
-            }
-        with idle.request_witnesses():
+            # checks if input schema is correct
+            if "inputs" not in deserialized_body and "instances" not in deserialized_body:
+                raise ValueError(
+                    f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}"
+                )
+
+            # Decode base64 audio inputs before running inference
+            if "parameters" in deserialized_body and HF_TASK in {
+                "automatic-speech-recognition",
+                "audio-classification",
+            }:
+                # Be more strict on base64 decoding, the provided string should valid base64 encoded data
+                deserialized_body["inputs"] = base64.b64decode(
+                    deserialized_body["inputs"], validate=True
+                )
+
+            # check for query parameter and add them to the body
+            if request.query_params and "parameters" not in deserialized_body:
+                deserialized_body["parameters"] = convert_params_to_int_or_bool(
+                    dict(request.query_params)
+                )
+
+            # We lazily load pipelines for alt tasks
+
+            if task == "feature-extraction" and HF_TASK in [
+                "sentence-similarity",
+                "sentence-embeddings",
+                "sentence-ranking",
+            ]:
+                task = "sentence-embeddings"
+            inference_handler = INFERENCE_HANDLERS.get(task)
+            if not inference_handler:
+                with INFERENCE_HANDLERS_LOCK:
+                    if task not in INFERENCE_HANDLERS:
+                        inference_handler = get_inference_handler_either_custom_or_default_handler(
+                            HF_MODEL_DIR, task=task)
+                        INFERENCE_HANDLERS[task] = inference_handler
+                    else:
+                        inference_handler = INFERENCE_HANDLERS[task]
+            # tracks request time
+            start_time = perf_counter()
+
+            if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler):
+                deserialized_body['handler_params'] = {
+                    'request': request
+                }
+
+            logger.debug("Calling inference handler prediction routine")
             # run async not blocking call
             pred = await async_handler_call(inference_handler, deserialized_body)

-        # log request time
-        logger.info(
-            f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms"
-        )
+            # log request time
+            logger.info(
+                f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms"
+            )

-        if should_discard_left() and pred is None:
-            logger.info("No content returned as caller already left")
-            return Response(status_code=204)
-
-        # response extracts content from request
-        accept = request.headers.get("accept")
-        if accept is None or accept == "*/*":
-            accept = os.environ.get("DEFAULT_ACCEPT", "application/json")
-        logger.info("Request accepts %s", accept)
-        # deserialized and resonds with json
-        serialized_response_body = ContentType.get_serializer(accept).serialize(
-            pred, accept
-        )
-        return Response(serialized_response_body, media_type=accept)
-    except Exception as e:
-        logger.exception(e)
-        return Response(
-            Jsoner.serialize({"error": str(e)}),
-            status_code=400,
-            media_type="application/json",
-        )
+            if should_discard_left() and pred is None:
+                logger.info("No content returned as caller already left")
+                return Response(status_code=204)
+
+            # response extracts content from request
+            accept = request.headers.get("accept")
+            if accept is None or accept == "*/*":
+                accept = os.environ.get("DEFAULT_ACCEPT", "application/json")
+            logger.info("Request accepts %s", accept)
+            # deserialized and resonds with json
+            serialized_response_body = ContentType.get_serializer(accept).serialize(
+                pred, accept
+            )
+            return Response(serialized_response_body, media_type=accept)
+        except Exception as e:
+            logger.exception(e)
+            return Response(
+                Jsoner.serialize({"error": str(e)}),
+                status_code=400,
+                media_type="application/json",
+            )


 # Create app based on which cloud environment is used
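
Two things change in predict: the blocking first-request model download is offloaded with asyncio.to_thread so the event loop stays free, and the idle witness now wraps the entire call rather than just the inference step, so the worker cannot be reaped mid-request. A standalone sketch of the offload (function names are hypothetical; _download stands in for load_repository_from_hf):

import asyncio
import time

def _download(repo_id: str) -> str:
    time.sleep(2)  # stands in for a long blocking model download
    return f"/tmp/models/{repo_id}"

async def first_request(repo_id: str) -> str:
    # Runs the blocking download in a thread; other coroutines (health
    # checks, the worker heartbeat) keep running meanwhile.
    return await asyncio.to_thread(_download, repo_id)

if __name__ == "__main__":
    print(asyncio.run(first_request("my-org/my-model")))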
