|
1 | 1 | import importlib.util |
| 2 | +import ipaddress |
2 | 3 | import sys |
3 | 4 | from pathlib import Path |
4 | 5 |
|
| 6 | +import psutil |
| 7 | +from starlette.requests import Request |
| 8 | + |
5 | 9 | from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME |
6 | 10 | from huggingface_inference_toolkit.logging import logger |
7 | 11 |
|
@@ -99,3 +103,59 @@ def convert_params_to_int_or_bool(params): |
99 | 103 | if v == "true": |
100 | 104 | params[k] = True |
101 | 105 | return params |
| 106 | + |
| 107 | + |
| 108 | +def already_left(request: Request) -> bool: |
| 109 | + """ |
| 110 | + Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve |
| 111 | + the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway |
| 112 | + :param request: |
| 113 | + :return: bool |
| 114 | + """ |
| 115 | + # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return |
| 116 | + # the correct status. So we use the good old way to identify if the caller is still there. |
| 117 | + # In any case, if we are not sure, we return False |
| 118 | + logger.info("Checking if request caller already left") |
| 119 | + try: |
| 120 | + client = request.client |
| 121 | + host = client.host |
| 122 | + if not host: |
| 123 | + return False |
| 124 | + |
| 125 | + port = int(client.port) |
| 126 | + host = ipaddress.ip_address(host) |
| 127 | + |
| 128 | + if port <= 0 or port > 65535: |
| 129 | + logger.warning("Unexpected source port format for caller %s", port) |
| 130 | + return False |
| 131 | + counter = 0 |
| 132 | + for connection in psutil.net_connections(kind="tcp"): |
| 133 | + counter += 1 |
| 134 | + if connection.status != "ESTABLISHED": |
| 135 | + continue |
| 136 | + if not connection.raddr: |
| 137 | + continue |
| 138 | + if int(connection.raddr.port) != port: |
| 139 | + continue |
| 140 | + if ( |
| 141 | + not connection.raddr.ip |
| 142 | + or ipaddress.ip_address(connection.raddr.ip) != host |
| 143 | + ): |
| 144 | + continue |
| 145 | + logger.info( |
| 146 | + "Found caller connection still established, caller is most likely still there, %s", |
| 147 | + connection, |
| 148 | + ) |
| 149 | + return False |
| 150 | + except Exception as e: |
| 151 | + logger.warning( |
| 152 | + "Unexpected error while checking if caller already left, assuming still there" |
| 153 | + ) |
| 154 | + logger.exception(e) |
| 155 | + return False |
| 156 | + |
| 157 | + logger.info( |
| 158 | + "%d connections checked. No connection found matching to the caller, probably left", |
| 159 | + counter, |
| 160 | + ) |
| 161 | + return True |
0 commit comments