|
1 | 1 | import importlib.util |
| 2 | +import ipaddress |
| 3 | +import os |
2 | 4 | import sys |
3 | 5 | from pathlib import Path |
4 | 6 |
|
| 7 | +import psutil |
| 8 | +from starlette.requests import Request |
| 9 | + |
5 | 10 | from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME |
6 | 11 | from huggingface_inference_toolkit.logging import logger |
7 | 12 |
|
@@ -66,7 +71,7 @@ def check_and_register_custom_pipeline_from_directory(model_dir): |
66 | 71 | # init custom handler with model_dir |
67 | 72 | custom_pipeline = handler.EndpointHandler(model_dir) |
68 | 73 | else: |
69 | | - logger.info(f"No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module) |
| 74 | + logger.info("No spec from file location found for module %s, file %s", HF_MODULE_NAME, custom_module) |
70 | 75 | elif legacy_module.is_file(): |
71 | 76 | logger.warning( |
72 | 77 | """You are using a legacy custom pipeline. |
@@ -99,3 +104,63 @@ def convert_params_to_int_or_bool(params): |
99 | 104 | if v == "true": |
100 | 105 | params[k] = True |
101 | 106 | return params |
| 107 | + |
| 108 | + |
| 109 | +def should_discard_left() -> bool: |
| 110 | + return os.getenv('DISCARD_LEFT', '0').lower() in ['true', 'yes', '1'] |
| 111 | + |
| 112 | + |
| 113 | +def already_left(request: Request) -> bool: |
| 114 | + """ |
| 115 | + Check if the caller has already left without waiting for the answer to come. This can help during burst to relieve |
| 116 | + the pressure on the worker by cancelling jobs whose results don't matter as they won't be fetched anyway |
| 117 | + :param request: |
| 118 | + :return: bool |
| 119 | + """ |
| 120 | + # NOTE: Starlette method request.is_disconnected is totally broken, consumes the payload, does not return |
| 121 | + # the correct status. So we use the good old way to identify if the caller is still there. |
| 122 | + # In any case, if we are not sure, we return False |
| 123 | + logger.info("Checking if request caller already left") |
| 124 | + try: |
| 125 | + client = request.client |
| 126 | + host = client.host |
| 127 | + if not host: |
| 128 | + return False |
| 129 | + |
| 130 | + port = int(client.port) |
| 131 | + host = ipaddress.ip_address(host) |
| 132 | + |
| 133 | + if port <= 0 or port > 65535: |
| 134 | + logger.warning("Unexpected source port format for caller %s", port) |
| 135 | + return False |
| 136 | + counter = 0 |
| 137 | + for connection in psutil.net_connections(kind="tcp"): |
| 138 | + counter += 1 |
| 139 | + if connection.status != "ESTABLISHED": |
| 140 | + continue |
| 141 | + if not connection.raddr: |
| 142 | + continue |
| 143 | + if int(connection.raddr.port) != port: |
| 144 | + continue |
| 145 | + if ( |
| 146 | + not connection.raddr.ip |
| 147 | + or ipaddress.ip_address(connection.raddr.ip) != host |
| 148 | + ): |
| 149 | + continue |
| 150 | + logger.info( |
| 151 | + "Found caller connection still established, caller is most likely still there, %s", |
| 152 | + connection, |
| 153 | + ) |
| 154 | + return False |
| 155 | + except Exception as e: |
| 156 | + logger.warning( |
| 157 | + "Unexpected error while checking if caller already left, assuming still there" |
| 158 | + ) |
| 159 | + logger.exception(e) |
| 160 | + return False |
| 161 | + |
| 162 | + logger.info( |
| 163 | + "%d connections checked. No connection found matching to the caller, probably left", |
| 164 | + counter, |
| 165 | + ) |
| 166 | + return True |
0 commit comments