Skip to content

Commit 935e4f4

Browse files
committed
feat(relieve): discard request if the caller is not waiting for the answer anymore
When behind a proxy, this requires the proxy to close the connection to be effective, though. Signed-off-by: Raphael Glon <[email protected]>
1 parent c0a0e42 commit 935e4f4

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ torch==2.5.1
2020
torchvision
2121
torchaudio
2222
peft==0.15.1
23+
psutil>=6.0.0

src/huggingface_inference_toolkit/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import importlib.util
2+
import ipaddress
23
import sys
34
from pathlib import Path
45

6+
import psutil
7+
from starlette.requests import Request
8+
59
from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
610
from huggingface_inference_toolkit.logging import logger
711

@@ -99,3 +103,59 @@ def convert_params_to_int_or_bool(params):
99103
if v == "true":
100104
params[k] = True
101105
return params
106+
107+
108+
def already_left(request: Request) -> bool:
    """
    Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve
    the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway.

    :param request: the incoming Starlette request whose client (host, port) pair is matched
        against the host's established TCP connections.
    :return: True only when we are confident the caller's TCP connection is gone; False whenever
        we are unsure (missing client info, invalid port, any unexpected error, or a matching
        ESTABLISHED connection is found).
    """
    # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return
    # the correct status. So we use the good old way to identify if the caller is still there.
    # In any case, if we are not sure, we return False
    logger.info("Checking if request caller already left")
    try:
        client = request.client
        # request.client can be None (e.g. unix domain socket, test client). Without a peer
        # address there is nothing to match against, so assume the caller is still there
        # instead of letting an AttributeError fall through to the broad handler below.
        if client is None or not client.host:
            return False

        port = int(client.port)
        host = ipaddress.ip_address(client.host)

        if port <= 0 or port > 65535:
            logger.warning("Unexpected source port format for caller %s", port)
            return False
        counter = 0
        for connection in psutil.net_connections(kind="tcp"):
            counter += 1
            # Only an ESTABLISHED connection whose remote endpoint matches the caller
            # proves the caller is still connected; skip everything else.
            if connection.status != psutil.CONN_ESTABLISHED:
                continue
            if not connection.raddr:
                continue
            if int(connection.raddr.port) != port:
                continue
            if (
                not connection.raddr.ip
                or ipaddress.ip_address(connection.raddr.ip) != host
            ):
                continue
            logger.info(
                "Found caller connection still established, caller is most likely still there, %s",
                connection,
            )
            return False
    except Exception as e:
        # Best effort only: this check must never break the request path, so on any
        # unexpected failure we assume the caller is still waiting.
        logger.warning(
            "Unexpected error while checking if caller already left, assuming still there"
        )
        logger.exception(e)
        return False

    logger.info(
        "%d connections checked. No connection found matching to the caller, probably left",
        counter,
    )
    return True

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from huggingface_inference_toolkit.logging import logger
2828
from huggingface_inference_toolkit.serialization.base import ContentType
2929
from huggingface_inference_toolkit.serialization.json_utils import Jsoner
30-
from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool
30+
from huggingface_inference_toolkit.utils import convert_params_to_int_or_bool, already_left
3131
from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
3232

3333
INFERENCE_HANDLERS = {}
@@ -101,6 +101,11 @@ async def metrics(request):
101101

102102
async def predict(request):
103103
global INFERENCE_HANDLERS
104+
105+
if os.getenv("DISCARD_LEFT", "0").lower() in ["1", "true", "yes"] and already_left(request):
106+
logger.info("Discarding request as the caller already left")
107+
return Response(status_code=204)
108+
104109
if not MODEL_DOWNLOADED:
105110
with MODEL_DL_LOCK:
106111
_eager_model_dl()

0 commit comments

Comments
 (0)