Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ filterwarnings = [

[tool.coverage.run]
source = ["src"]
concurrency = ["multiprocessing", "thread"]
parallel = true
sigterm = true
omit = [
"*/tests/*",
"*/test_*",
Expand Down
14 changes: 8 additions & 6 deletions src/inference_endpoint/endpoint_client/accumulator_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ class SSEAccumulatorProtocol(Protocol):
is disabled, only the first chunk is emitted via add_chunk().
"""

def __init__(self, query_id: str, stream_all_chunks: bool) -> None:
def __init__(
self, query_id: str, stream_all_chunks: bool
) -> None: # pragma: no cover
"""
Initialize the accumulator.

Args:
query_id: Unique identifier for the request being accumulated.
stream_all_chunks: If True, emit all chunks; if False, only first chunk.
"""
pass
...

def add_chunk(self, delta: Any) -> StreamChunk | None:
def add_chunk(self, delta: Any) -> StreamChunk | None: # pragma: no cover
"""
Process an SSE delta and optionally emit a StreamChunk.

Expand All @@ -54,9 +56,9 @@ def add_chunk(self, delta: Any) -> StreamChunk | None:
Returns None for empty deltas, or after first chunk when
stream_all_chunks=False (TTFT-only mode).
"""
pass
...

def get_final_output(self) -> QueryResult:
def get_final_output(self) -> QueryResult: # pragma: no cover
"""
Return the final accumulated result after stream completion.

Expand All @@ -66,4 +68,4 @@ def get_final_output(self) -> QueryResult:
Returns:
QueryResult with the complete response output.
"""
pass
...
74 changes: 41 additions & 33 deletions src/inference_endpoint/endpoint_client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ class PooledConnection:
"last_used",
"in_use",
"idle_time_on_acquire",
"_fd",
"_stale_poller",
)

def __init__(
Expand All @@ -385,6 +387,11 @@ def __init__(
self.in_use = True
self.idle_time_on_acquire = 0.0

# Cache fd for stale checks — stable for the lifetime of the connection
sock = transport.get_extra_info("socket")
self._fd: int = sock.fileno() if sock is not None else -1
self._stale_poller: select.poll | None = None

def is_alive(self) -> bool:
"""Check if the connection is still usable.

Expand All @@ -405,31 +412,34 @@ def is_stale(self) -> bool:
For idle HTTP keep-alive connections, there should be no pending data.
If the socket is readable, it means the server sent FIN (EOF).

Optimization: Skip check for recently-used connections (< 1 second).
Uses poll() instead of select() to avoid FD_SETSIZE limit on high fds.
Poller is created lazily on first call and reused (fd is stable per connection).
"""
# Skip stale check for recently-used connections
# Server unlikely to close within 1 second of last use
if time.monotonic() - self.last_used < 1.0:
return False

if self.transport is None:
return True
# Fast path: poller already registered from a previous call
if self._stale_poller is not None:
try:
return bool(self._stale_poller.poll(0))
except (OSError, ValueError):
# fd closed or invalid — connection is dead, treat as stale
return True

# Get the socket file descriptor
sock = self.transport.get_extra_info("socket")
if sock is None:
# Slow path: first call — create poller and register fd
if self._fd < 0:
return True

try:
fd = sock.fileno()
if fd < 0:
return True

# Use select with zero timeout - avoids poll() object creation overhead
readable, _, exceptional = select.select([fd], [], [fd], 0)
return bool(readable or exceptional)
poller = select.poll()
poller.register(self._fd, select.POLLIN | select.POLLERR | select.POLLHUP)
self._stale_poller = poller
return bool(poller.poll(0))

except (OSError, ValueError):
# fd closed or invalid — connection is dead, treat as stale
return True


Expand Down Expand Up @@ -660,10 +670,11 @@ class HttpRequestTemplate:
that remain constant across requests to a given endpoint.

Attributes:
static_prefix: Pre-merged request line + host header bytes
static_prefix: Pre-merged request line + host header bytes.
cached_headers: Pre-encoded headers from cache_headers(), included in every request.
"""

__slots__ = ("static_prefix", "_extra_headers_cache", "extra_cached_headers")
__slots__ = ("static_prefix", "cached_headers")

# Pre-encoded general headers
HEADERS_STREAMING = (
Expand All @@ -675,8 +686,7 @@ class HttpRequestTemplate:

def __init__(self, static_prefix: bytes):
self.static_prefix = static_prefix
self._extra_headers_cache: dict[frozenset, bytes] = {}
self.extra_cached_headers = b""
self.cached_headers = b""

@classmethod
def from_url(cls, host: str, port: int, path: str = "/") -> HttpRequestTemplate:
Expand Down Expand Up @@ -714,12 +724,13 @@ def cache_headers(self, headers: dict[str, str]) -> None:
Args:
headers: Headers to pre-encode and cache
"""
cache_key = frozenset(headers.items())
if cache_key not in self._extra_headers_cache:
self._extra_headers_cache[cache_key] = "".join(
f"{k}: {v}\r\n" for k, v in headers.items()
).encode("utf-8", "surrogateescape")
self.extra_cached_headers = b"".join(self._extra_headers_cache.values())
encoded = "".join(f"{k}: {v}\r\n" for k, v in headers.items()).encode(
"utf-8", "surrogateescape"
)
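# Substring dedup: relies on this being called at setup with complete
# header lines (e.g. "Authorization: Bearer ...\r\n"), not arbitrary fragments.
# Caveat: this is a raw byte-substring test, so a line that is a suffix of an
# already-cached line (e.g. "Bar: 1\r\n" after "Foo-Bar: 1\r\n") is also skipped.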
if encoded not in self.cached_headers:
self.cached_headers += encoded

def build_request(
self,
Expand Down Expand Up @@ -748,25 +759,22 @@ def build_request(
return b"".join(
[
self.static_prefix,
self.extra_cached_headers,
self.cached_headers,
content_type_headers,
content_length,
body,
]
)

# Slow path: extra headers (~1us uncached, ~50ns per extra-header cached)
cache_key = frozenset(extra_headers.items())
if (extra := self._extra_headers_cache.get(cache_key)) is None:
extra = "".join(f"{k}: {v}\r\n" for k, v in extra_headers.items()).encode(
"utf-8", "surrogateescape"
)
self._extra_headers_cache[cache_key] = extra

# Slow path: extra headers are encoded per-call;
# use cache_headers() at setup time for headers that repeat every request.
extra = "".join(f"{k}: {v}\r\n" for k, v in extra_headers.items()).encode(
"utf-8", "surrogateescape"
)
return b"".join(
[
self.static_prefix,
self.extra_cached_headers,
self.cached_headers,
content_type_headers,
extra,
content_length,
Expand Down
2 changes: 1 addition & 1 deletion src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# - uvloop requires use of 'spawn'
try:
multiprocessing.set_start_method("spawn", force=False)
except RuntimeError:
except RuntimeError: # pragma: no cover
# Already set, which is fine (likely in tests or when importing multiple times)
pass

Expand Down
Loading