
Commit 52511c0

feat(relieve): discard request if the caller is not waiting for the answer anymore

When behind a proxy, this requires the proxy to close the connection to be effective.

Signed-off-by: Raphael Glon <[email protected]>
1 parent c0a0e42 commit 52511c0

6 files changed: +98 −7 lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@ torch==2.5.1
 torchvision
 torchaudio
 peft==0.15.1
+psutil>=6.0.0
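
The new psutil dependency backs the connection scan added to utils.py further down. A minimal sketch of the one API the commit relies on, psutil.net_connections (illustrative only, not part of the diff):

```python
import psutil

# Each entry is a namedtuple: (fd, family, type, laddr, raddr, status, pid).
for conn in psutil.net_connections(kind="tcp"):
    if conn.status == psutil.CONN_ESTABLISHED and conn.raddr:
        # remote peer of every live TCP connection on this host
        print(conn.raddr.ip, conn.raddr.port)
```

Note that enumerating sockets may require elevated privileges on some platforms (notably macOS); on Linux it works unprivileged.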

src/huggingface_inference_toolkit/handler.py

Lines changed: 18 additions & 3 deletions
@@ -5,7 +5,12 @@
 from huggingface_inference_toolkit import logging
 from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
 from huggingface_inference_toolkit.env_utils import api_inference_compat, ignore_custom_handler
-from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory
+from huggingface_inference_toolkit.logging import logger
+from huggingface_inference_toolkit.utils import (
+    already_left,
+    check_and_register_custom_pipeline_from_directory,
+    should_discard_left,
+)


 class HuggingFaceHandler:
@@ -39,7 +44,17 @@ def __call__(self, data: Dict[str, Any]):
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})

-        # diffusers and sentence transformers pipelines do not have the `task` arg
+        if "handler_params" in data:
+            handler_params = data.pop("handler_params")
+            if should_discard_left():
+                request = handler_params.get("request")
+                if not request:
+                    logger.warn("Cannot know if request caller already left, missing request handler param")
+                elif already_left(request):
+                    logger.info("Discarding request as the caller already left")
+                    return None
+
+        # diffusers and sentence transformers pipelines do not have the `task` arg
         if not hasattr(self.pipeline, "task"):
             # sentence transformers parameters not supported yet
             if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()):
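
To make the new branch concrete, here is a hypothetical call site showing the payload shape the handler now understands; `handler` and `request` stand in for the live HuggingFaceHandler and Starlette request that the web layer (see webservice_starlette.py below) actually wires together:

```python
# Sketch only: `handler` is a HuggingFaceHandler instance, `request` the live
# starlette.requests.Request injected by the web layer.
payload = {
    "inputs": "Hello world",
    "handler_params": {"request": request},
}
pred = handler(payload)
if pred is None:
    # DISCARD_LEFT was enabled and already_left(request) returned True:
    # the job was dropped before ever touching the pipeline.
    ...
```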
@@ -168,7 +183,7 @@ def __call__(self, data: Dict[str, Any]):
                 scores = resp['scores']
                 if len(labels) == len(scores):
                     new_resp = []
-                    for label, score in zip(labels, scores):
+                    for label, score in zip(labels, scores, strict=True):
                         new_resp.append({"label": label, "score": score})
                     resp = new_resp
                 else:
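
The `strict=True` flag (Python 3.10+) is purely defensive here, since the surrounding `len(labels) == len(scores)` check already guarantees equal lengths; with the flag, a future regression raises instead of silently truncating:

```python
labels = ["positive", "negative"]
scores = [0.9]

print(list(zip(labels, scores)))  # [('positive', 0.9)] -- shorter input wins silently

try:
    list(zip(labels, scores, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```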

src/huggingface_inference_toolkit/heavy_utils.py

Lines changed: 1 addition & 1 deletion
@@ -184,4 +184,4 @@ def get_pipeline(
         hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids(
             language="english", task="transcribe"
         )
-    return hf_pipeline # type: ignore
+    return hf_pipeline  # type: ignore

src/huggingface_inference_toolkit/serialization/base.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class ContentType:
     @staticmethod
     def get_deserializer(content_type: str, task: str):
         if not content_type:
-            message = f"No content type provided and no default one configured."
+            message = "No content type provided and no default one configured."
             raise Exception(message)
         if content_type.lower().startswith("application/octet-stream"):
            if "audio" in task or "speech" in task:

src/huggingface_inference_toolkit/utils.py

Lines changed: 66 additions & 1 deletion
@@ -1,7 +1,12 @@
 import importlib.util
+import ipaddress
+import os
 import sys
 from pathlib import Path

+import psutil
+from starlette.requests import Request
+
 from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
 from huggingface_inference_toolkit.logging import logger

@@ -66,7 +71,7 @@ def check_and_register_custom_pipeline_from_directory(model_dir):
             # init custom handler with model_dir
             custom_pipeline = handler.EndpointHandler(model_dir)
         else:
-            logger.info(f"No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module)
+            logger.info("No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module)
     elif legacy_module.is_file():
         logger.warning(
             """You are using a legacy custom pipeline.
@@ -99,3 +104,63 @@ def convert_params_to_int_or_bool(params):
         if v == "true":
             params[k] = True
     return params
+
+
+def should_discard_left() -> bool:
+    return os.getenv('DISCARD_LEFT', '0').lower() in ['true', 'yes', '1']
+
+
+def already_left(request: Request) -> bool:
+    """
+    Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve
+    the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway
+    :param request:
+    :return: bool
+    """
+    # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return
+    # the correct status. So we use the good old way to identify if the caller is still there.
+    # In any case, if we are not sure, we return False
+    logger.info("Checking if request caller already left")
+    try:
+        client = request.client
+        host = client.host
+        if not host:
+            return False
+
+        port = int(client.port)
+        host = ipaddress.ip_address(host)
+
+        if port <= 0 or port > 65535:
+            logger.warning("Unexpected source port format for caller %s", port)
+            return False
+        counter = 0
+        for connection in psutil.net_connections(kind="tcp"):
+            counter += 1
+            if connection.status != "ESTABLISHED":
+                continue
+            if not connection.raddr:
+                continue
+            if int(connection.raddr.port) != port:
+                continue
+            if (
+                not connection.raddr.ip
+                or ipaddress.ip_address(connection.raddr.ip) != host
+            ):
+                continue
+            logger.info(
+                "Found caller connection still established, caller is most likely still there, %s",
+                connection,
+            )
+            return False
+    except Exception as e:
+        logger.warning(
+            "Unexpected error while checking if caller already left, assuming still there"
+        )
+        logger.exception(e)
+        return False
+
+    logger.info(
+        "%d connections checked. No connection found matching to the caller, probably left",
+        counter,
+    )
+    return True
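
A sketch of how the two helpers could be exercised in isolation. The scope dict and addresses are invented for the example; since no real peer is connected from a TEST-NET address, already_left should come back True:

```python
import os

from starlette.requests import Request

from huggingface_inference_toolkit.utils import already_left, should_discard_left

os.environ["DISCARD_LEFT"] = "1"  # the flag read by should_discard_left()

# Minimal ASGI scope; request.client is built from the (host, port) pair.
scope = {
    "type": "http",
    "method": "POST",
    "path": "/",
    "headers": [],
    "client": ("203.0.113.7", 54321),  # TEST-NET-3 address: no such peer exists
}
request = Request(scope)

if should_discard_left() and already_left(request):
    print("caller is gone, the job can be dropped")
```

Note the deliberate asymmetry: any doubt (missing host, odd port, psutil failure) yields False, so a request is only discarded on positive evidence that the caller's TCP connection is gone.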

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 11 additions & 1 deletion
@@ -22,12 +22,13 @@
 )
 from huggingface_inference_toolkit.env_utils import api_inference_compat
 from huggingface_inference_toolkit.handler import (
+    HuggingFaceHandler,
     get_inference_handler_either_custom_or_default_handler,
 )
 from huggingface_inference_toolkit.logging import logger
 from huggingface_inference_toolkit.serialization.base import ContentType
 from huggingface_inference_toolkit.serialization.json_utils import Jsoner
-from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool
+from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool, should_discard_left
 from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs

 INFERENCE_HANDLERS = {}
@@ -101,6 +102,7 @@ async def metrics(request):

 async def predict(request):
     global INFERENCE_HANDLERS
+
     if not MODEL_DOWNLOADED:
         with MODEL_DL_LOCK:
             _eager_model_dl()
@@ -154,6 +156,10 @@ async def predict(request):
     # tracks request time
     start_time = perf_counter()

+    if should_discard_left() and isinstance(inference_handler, HuggingFaceHandler):
+        deserialized_body['handler_params'] = {
+            'request': request
+        }
     with idle.request_witnesses():
         # run async not blocking call
         pred = await async_handler_call(inference_handler, deserialized_body)
@@ -163,6 +169,10 @@ async def predict(request):
         f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms"
     )

+    if should_discard_left() and pred is None:
+        logger.info("No content returned as caller already left")
+        return Response(status_code=204)
+
     # response extracts content from request
     accept = request.headers.get("accept")
     if accept is None or accept == "*/*":
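
End to end, the feature could be exercised like this; the endpoint URL, port, and client code are assumptions for illustration, since only the server side is defined by this commit. A caller that gives up closes its socket; the worker's connection scan then comes up empty, inference is skipped, and the answer is an (unread) body-less 204:

```python
import httpx

# Hypothetical client against a server started with DISCARD_LEFT=1.
try:
    httpx.post(
        "http://localhost:5000/predict",   # assumed endpoint and port
        json={"inputs": "some long generation"},
        timeout=0.5,                       # give up quickly; the closed socket is the signal
    )
except httpx.TimeoutException:
    pass  # caller is gone; server-side inference is discarded
```

As the commit message warns, this only works when the worker actually observes the disconnect: behind a proxy, the proxy must close its upstream connection when the client leaves, otherwise the ESTABLISHED check keeps passing.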
