Skip to content

Commit ff66399

Browse files
authored
Merge branch 'main' into feat/attafosu/sglang-openai-api-compatibility
2 parents 1572da4 + 760fb88 commit ff66399

File tree

15 files changed

+118
-23
lines changed

15 files changed

+118
-23
lines changed

src/inference_endpoint/commands/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _build_config_from_cli(
325325
client=ClientSettings(
326326
workers=args.workers if args.workers else -1,
327327
log_level="DEBUG" if verbose_level >= 2 else "INFO",
328-
warmup_connections=getattr(args, "warmup_connections", True),
328+
warmup_connections=getattr(args, "warmup_connections", -1),
329329
max_connections=getattr(args, "max_connections", None) or -1,
330330
),
331331
),

src/inference_endpoint/commands/probe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ async def run_probe_command(args: argparse.Namespace) -> None:
7272
],
7373
api_type=api_type,
7474
num_workers=1,
75-
warmup_connections=False,
75+
warmup_connections=0,
7676
)
7777
# Client creates its own event loop in a separate thread
7878
client = HTTPEndpointClient(http_config, zmq_context=zmq_ctx)

src/inference_endpoint/config/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ class ClientSettings(BaseModel):
292292
log_level: str = "INFO"
293293

294294
# Pre-establish TCP connections during init for reuse at runtime.
295-
warmup_connections: bool = True
295+
# Values: -1 = auto (50% of pool), 0 = disabled, >0 = explicit total count
296+
warmup_connections: int = -1
296297

297298
# Maximum concurrent TCP connections per worker.
298299
# -1 = unlimited (bound by system ephemeral port limit)

src/inference_endpoint/metrics/reporter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,12 @@ def derive_TPOT(
10751075
output_sequence, reasoning_sequence = output_sequence_from_data(
10761076
data_bytes, join_chunks=False
10771077
)
1078+
if isinstance(output_sequence, str):
1079+
output_sequence = [output_sequence]
10781080
if not isinstance(output_sequence, list):
1081+
logging.warning(
1082+
f"Output sequence for sample {sample_uuid} is not a list but {type(output_sequence)}: {output_sequence}"
1083+
)
10791084
continue
10801085

10811086
all_chunks = output_sequence

src/inference_endpoint/utils/benchmark_httpclient.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import time
3838
from dataclasses import dataclass
3939

40+
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
4041
from inference_endpoint.core.types import Query, QueryResult
4142
from inference_endpoint.endpoint_client.config import HTTPClientConfig
4243
from inference_endpoint.endpoint_client.cpu_affinity import compute_affinity_plan
@@ -399,6 +400,7 @@ def _create_client(
399400
prompt: str,
400401
enable_affinity: bool,
401402
verbose: bool = True,
403+
zmq_context: ManagedZMQContext | None = None,
402404
) -> tuple:
403405
"""Create an endpoint client and query data dict.
404406
@@ -422,7 +424,7 @@ def _create_client(
422424
endpoint_urls=[endpoint_url],
423425
num_workers=num_workers if num_workers > 0 else -1,
424426
max_connections=max_connections if max_connections > 0 else -1,
425-
warmup_connections=False,
427+
warmup_connections=0,
426428
worker_gc_mode="relaxed",
427429
log_level="CRITICAL",
428430
cpu_affinity=cpu_affinity_plan,
@@ -434,7 +436,7 @@ def _create_client(
434436
f"max_connections={config.max_connections}, stream={streaming}"
435437
)
436438

437-
client = AsyncHttpEndpointClient(config)
439+
client = AsyncHttpEndpointClient(config, zmq_context=zmq_context)
438440
query_data = {
439441
"prompt": prompt,
440442
"model": "benchmark-model",
@@ -488,13 +490,17 @@ def run_benchmark(
488490
except OSError:
489491
pass
490492

493+
zmq_ctx_manager = ManagedZMQContext.scoped()
494+
zmq_ctx = zmq_ctx_manager.__enter__()
495+
491496
client, query_data = _create_client(
492497
endpoint_url,
493498
num_workers,
494499
max_connections,
495500
streaming,
496501
prompt,
497502
enable_affinity,
503+
zmq_context=zmq_ctx,
498504
)
499505
loop = client.loop
500506
stats = BenchmarkStats(sse_events_per_response=sse_events_per_response)
@@ -613,6 +619,7 @@ async def receiver():
613619
gc.collect()
614620

615621
asyncio.run_coroutine_threadsafe(client.shutdown(), loop).result(timeout=10.0)
622+
zmq_ctx_manager.__exit__(None, None, None)
616623

617624
# Restore original affinity so the next sweep iteration sees all CPUs
618625
if saved_affinity is not None:

tests/integration/commands/test_benchmark_command.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def test_offline_benchmark_with_echo_server(
5959
verbose=1,
6060
model="echo-server",
6161
timeout=None,
62-
warmup_connections=False,
62+
warmup_connections=0,
6363
)
6464

6565
with caplog.at_level("INFO"):
@@ -99,7 +99,7 @@ async def test_online_benchmark_with_echo_server(
9999
verbose=1,
100100
model="echo-server",
101101
timeout=None,
102-
warmup_connections=False,
102+
warmup_connections=0,
103103
)
104104
with caplog.at_level("INFO"):
105105
await run_benchmark_command(args)
@@ -143,7 +143,7 @@ async def test_benchmark_with_output_file(
143143
verbose=0,
144144
model="echo-server",
145145
timeout=None,
146-
warmup_connections=False,
146+
warmup_connections=0,
147147
)
148148

149149
await run_benchmark_command(args)
@@ -185,7 +185,7 @@ async def test_benchmark_mode_logging(
185185
verbose=1,
186186
model="echo-server",
187187
timeout=None,
188-
warmup_connections=False,
188+
warmup_connections=0,
189189
)
190190
with caplog.at_level("INFO"):
191191
await run_benchmark_command(args)

tests/integration/endpoint_client/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def create_futures_client(
2626
url: str,
2727
num_workers: int = 1,
2828
max_connections: int = 10,
29-
warmup_connections: bool = False,
29+
warmup_connections: int = 0,
3030
zmq_context=None,
3131
) -> FuturesHttpClient:
3232
"""Helper to create a FuturesHttpClient with specific config.
@@ -35,7 +35,7 @@ def create_futures_client(
3535
url: The endpoint URL to connect to
3636
num_workers: Number of worker processes (default: 1)
3737
max_connections: Max connections per worker (default: 10 for tests)
38-
warmup_connections: Whether to warmup connections (default: False for tests)
38+
warmup_connections: Warmup connection count (0 = disabled, -1 = auto, >0 = explicit)
3939
zmq_context: ManagedZMQContext when using ZMQ transport (required by default config).
4040
4141
Returns:

tests/integration/endpoint_client/test_external_serving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _create_custom_client(
4141
endpoint_urls=[f"{vllm_docker_server['url']}/v1/chat/completions"],
4242
num_workers=num_workers,
4343
max_connections=50,
44-
warmup_connections=False,
44+
warmup_connections=0,
4545
)
4646

4747
# TODO(vir):

tests/integration/endpoint_client/test_http_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ async def test_many_workers(self, mock_http_echo_server):
156156
num_workers=num_workers,
157157
max_connections=num_workers
158158
* 10, # ensure each worker has connections
159-
warmup_connections=False,
159+
warmup_connections=0,
160160
zmq_context=zmq_ctx,
161161
)
162162

@@ -330,7 +330,7 @@ async def test_streaming_error_propagation(self):
330330
# Use invalid endpoint to trigger errors
331331
client = create_futures_client(
332332
"http://invalid-endpoint-12345:9999/v1/chat/completions",
333-
warmup_connections=False,
333+
warmup_connections=0,
334334
zmq_context=zmq_ctx,
335335
)
336336

tests/integration/endpoint_client/test_sglang_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def sglang_futures_client():
4343
endpoint_urls=[SGLANG_ENDPOINT],
4444
num_workers=4,
4545
api_type="sglang",
46-
warmup_connections=False,
46+
warmup_connections=0,
4747
)
4848

4949
client = FuturesHttpClient(http_config)

0 commit comments

Comments
 (0)