Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/inference_endpoint/commands/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _build_config_from_cli(
client=ClientSettings(
workers=args.workers if args.workers else -1,
log_level="DEBUG" if verbose_level >= 2 else "INFO",
warmup_connections=getattr(args, "warmup_connections", True),
warmup_connections=getattr(args, "warmup_connections", -1),
max_connections=getattr(args, "max_connections", None) or -1,
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/inference_endpoint/commands/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def run_probe_command(args: argparse.Namespace) -> None:
],
api_type=api_type,
num_workers=1,
warmup_connections=False,
warmup_connections=0,
)
# Client creates its own event loop in a separate thread
client = HTTPEndpointClient(http_config, zmq_context=zmq_ctx)
Expand Down
3 changes: 2 additions & 1 deletion src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ class ClientSettings(BaseModel):
log_level: str = "INFO"

# Pre-establish TCP connections during init for reuse at runtime.
warmup_connections: bool = True
# Values: -1 = auto (50% of pool), 0 = disabled, >0 = explicit total count
warmup_connections: int = -1

# Maximum concurrent TCP connections per worker.
# -1 = unlimited (bound by system ephemeral port limit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from inference_endpoint.dataset_manager.transforms import (
AddStaticColumns,
Harmonize,
Transform,
UserPromptFormatter,
)
Expand Down Expand Up @@ -48,3 +49,28 @@ def llama3_8b(
),
AddStaticColumns(chat_template),
]


def llama3_8b_sglang(
    stream: bool = True,
    max_new_tokens: int = 128,
    temperature: float = 0.0,
    top_p: float = 1.0,
    top_k: int = 1,
    tokenizer_name: str = "meta-llama/Llama-3.1-8B-Instruct",
) -> list[Transform]:
    """Build the dataset transform pipeline for Llama-3 8B against SGLang.

    The pipeline formats a summarization prompt from the "article" column and
    then tokenizes it with the model's own tokenizer (plain mode, no Harmony
    conversation rendering).

    Args:
        stream: NOTE(review): currently unused in this preset — kept for
            signature parity with other presets; wire into the request
            settings or remove. TODO confirm intended use.
        max_new_tokens: Target summary length in tokens; interpolated into the
            prompt text (it does not cap generation by itself).
        temperature: NOTE(review): currently unused — see ``stream``.
        top_p: NOTE(review): currently unused — see ``stream``.
        top_k: NOTE(review): currently unused — see ``stream``.
        tokenizer_name: Hugging Face tokenizer used by ``Harmonize`` in plain
            mode to produce ``input_tokens``.

    Returns:
        Ordered list of transforms: prompt formatting, then plain tokenization.
    """
    return [
        # Step 1: Format the prompt from "article"
        UserPromptFormatter(
            user_prompt_format=f"Summarize the following news article in {max_new_tokens} tokens. Please output the summary only, without any other text.\n\nArticle:\n{{article}}\n\nSummary:",
            output_column="prompt",
        ),
        # Step 2: Tokenize the raw prompt via Harmonize in plain mode.
        # harmonized_column=None: do not store the decoded text, only tokens.
        Harmonize(
            tokenizer_name=tokenizer_name,
            prompt_column="prompt",
            tokenized_column="input_tokens",
            harmonized_column=None,
            mode="plain",
        ),
    ]
19 changes: 18 additions & 1 deletion src/inference_endpoint/dataset_manager/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
prompt_column: str = "prompt",
tokenized_column: str = "input_tokens",
harmonized_column: str | None = "harmonized_prompt",
mode: str = "harmony",
):
"""Initialize the Harmonize transform.

Expand All @@ -145,10 +146,14 @@ def __init__(
tokenized_column: The name of the column containing the tokenized prompt.
harmonized_column: The name of the column containing the harmonized prompt. If None,
the harmonized prompt will not be stored as text.
mode: "harmony" to render a Harmony conversation; "plain" to tokenize the raw prompt.
"""
self.prompt_column = prompt_column
self.tokenized_column = tokenized_column
self.harmonized_column = harmonized_column
self.mode = mode
if self.mode not in {"harmony", "plain"}:
raise ValueError(f"Invalid harmonize mode: {self.mode}")
Comment on lines +154 to +156
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Harmonize.__call__ still skips purely based on the presence of tokenized_column in df.columns, but process_row now skips only when the per-row value is non-null. This makes behavior differ depending on whether row processors are fused (or if fuse_row_processors=False is used). Consider aligning the dataframe-level skip logic with the row-level guard so the transform behaves consistently.

Copilot uses AI. Check for mistakes.
self.harmonizer = Harmonizer(
tokenizer_name=tokenizer_name,
encoding_name=encoding_name,
Comment on lines 157 to 159
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In mode="plain", process_row only calls self.harmonizer.to_tokens(...), but Harmonizer.__init__ still loads the Harmony encoding and constructs Harmony system content. That’s potentially expensive and unnecessary for plain tokenization. Consider a lightweight path for plain mode (e.g., defer encoding load until __call__ is used, or use the underlying tokenizer directly) to reduce init overhead.

Copilot uses AI. Check for mistakes.
Expand All @@ -171,7 +176,19 @@ def process_row(self, row: dict[str, Any]) -> dict[str, Any]:
Returns:
Row dictionary with the harmonized prompt added
"""
row[self.tokenized_column] = self.harmonizer(row[self.prompt_column])
# Guard pre-tokenized rows: the SGLang adapter adds a default Harmonize
# (GPT-OSS tokenizer + harmony mode). When row processors are fused, the
# dataframe-level skip is bypassed, so without this guard, adapter
# Harmonize would overwrite input tokens. Alternative: remove Harmonize
# from the adapter transforms and require each SGLang preset to add its
# own Harmonize with the desired tokenizer/args.
if self.tokenized_column in row and row[self.tokenized_column] is not None:
return row
Comment on lines +185 to +186
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pre-tokenized guard treats any non-None value as “already tokenized”. In pandas rows, missing values are often NaN (which is not None), so this would incorrectly skip tokenization and leave NaN in input_tokens, likely breaking downstream code that expects a list of token IDs. Consider using an explicit null check that treats NaN as missing (e.g., via pd.isna) before returning early.

Copilot uses AI. Check for mistakes.
if self.mode == "plain":
tokens = self.harmonizer.to_tokens(row[self.prompt_column])
row[self.tokenized_column] = tokens
else:
row[self.tokenized_column] = self.harmonizer(row[self.prompt_column])
Comment on lines 176 to +191
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change adds new Harmonize behavior (mode plus the overwrite-prevention guard when row processors are fused), but tests/unit/dataset_manager/test_transforms.py explicitly excludes Harmonize. Please add unit tests that cover (1) mode="plain" vs mode="harmony", and (2) fused pipelines where a second Harmonize should not overwrite existing input_tokens.

Copilot uses AI. Check for mistakes.
if self.harmonized_column is not None:
row[self.harmonized_column] = self.harmonizer.to_text(
row[self.tokenized_column]
Expand Down
5 changes: 5 additions & 0 deletions src/inference_endpoint/metrics/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,12 @@ def derive_TPOT(
output_sequence, reasoning_sequence = output_sequence_from_data(
data_bytes, join_chunks=False
)
if isinstance(output_sequence, str):
output_sequence = [output_sequence]
if not isinstance(output_sequence, list):
logging.warning(
f"Output sequence for sample {sample_uuid} is not a list but {type(output_sequence)}: {output_sequence}"
)
continue

all_chunks = output_sequence
Expand Down
6 changes: 3 additions & 3 deletions src/inference_endpoint/openai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=

role: str
content: str | None
refusal: str | None
refusal: str | None = None


class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
Expand All @@ -109,5 +109,5 @@ class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
created: int
model: str
choices: list[ChatCompletionChoice]
usage: CompletionUsage | None
system_fingerprint: str | None
usage: CompletionUsage | None = None
system_fingerprint: str | None = None
Comment on lines 83 to +113
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are no unit tests covering the msgspec OpenAI types / msgspec adapter decode path. Since these fields now default to None to support responses that omit them, it would be good to add a test that decodes a minimal OpenAI-compatible response missing refusal, usage, and system_fingerprint and asserts decoding succeeds and fields are None.

Copilot uses AI. Check for mistakes.
11 changes: 9 additions & 2 deletions src/inference_endpoint/utils/benchmark_httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import time
from dataclasses import dataclass

from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.core.types import Query, QueryResult
from inference_endpoint.endpoint_client.config import HTTPClientConfig
from inference_endpoint.endpoint_client.cpu_affinity import compute_affinity_plan
Expand Down Expand Up @@ -399,6 +400,7 @@ def _create_client(
prompt: str,
enable_affinity: bool,
verbose: bool = True,
zmq_context: ManagedZMQContext | None = None,
) -> tuple:
"""Create an endpoint client and query data dict.

Expand All @@ -422,7 +424,7 @@ def _create_client(
endpoint_urls=[endpoint_url],
num_workers=num_workers if num_workers > 0 else -1,
max_connections=max_connections if max_connections > 0 else -1,
warmup_connections=False,
warmup_connections=0,
worker_gc_mode="relaxed",
log_level="CRITICAL",
cpu_affinity=cpu_affinity_plan,
Expand All @@ -434,7 +436,7 @@ def _create_client(
f"max_connections={config.max_connections}, stream={streaming}"
)

client = AsyncHttpEndpointClient(config)
client = AsyncHttpEndpointClient(config, zmq_context=zmq_context)
query_data = {
"prompt": prompt,
"model": "benchmark-model",
Expand Down Expand Up @@ -488,13 +490,17 @@ def run_benchmark(
except OSError:
pass

zmq_ctx_manager = ManagedZMQContext.scoped()
zmq_ctx = zmq_ctx_manager.__enter__()

client, query_data = _create_client(
endpoint_url,
num_workers,
max_connections,
streaming,
prompt,
enable_affinity,
zmq_context=zmq_ctx,
)
loop = client.loop
stats = BenchmarkStats(sse_events_per_response=sse_events_per_response)
Expand Down Expand Up @@ -613,6 +619,7 @@ async def receiver():
gc.collect()

asyncio.run_coroutine_threadsafe(client.shutdown(), loop).result(timeout=10.0)
zmq_ctx_manager.__exit__(None, None, None)

# Restore original affinity so the next sweep iteration sees all CPUs
if saved_affinity is not None:
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/commands/test_benchmark_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def test_offline_benchmark_with_echo_server(
verbose=1,
model="echo-server",
timeout=None,
warmup_connections=False,
warmup_connections=0,
)

with caplog.at_level("INFO"):
Expand Down Expand Up @@ -99,7 +99,7 @@ async def test_online_benchmark_with_echo_server(
verbose=1,
model="echo-server",
timeout=None,
warmup_connections=False,
warmup_connections=0,
)
with caplog.at_level("INFO"):
await run_benchmark_command(args)
Expand Down Expand Up @@ -143,7 +143,7 @@ async def test_benchmark_with_output_file(
verbose=0,
model="echo-server",
timeout=None,
warmup_connections=False,
warmup_connections=0,
)

await run_benchmark_command(args)
Expand Down Expand Up @@ -185,7 +185,7 @@ async def test_benchmark_mode_logging(
verbose=1,
model="echo-server",
timeout=None,
warmup_connections=False,
warmup_connections=0,
)
with caplog.at_level("INFO"):
await run_benchmark_command(args)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/endpoint_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_futures_client(
url: str,
num_workers: int = 1,
max_connections: int = 10,
warmup_connections: bool = False,
warmup_connections: int = 0,
zmq_context=None,
) -> FuturesHttpClient:
"""Helper to create a FuturesHttpClient with specific config.
Expand All @@ -35,7 +35,7 @@ def create_futures_client(
url: The endpoint URL to connect to
num_workers: Number of worker processes (default: 1)
max_connections: Max connections per worker (default: 10 for tests)
warmup_connections: Whether to warmup connections (default: False for tests)
warmup_connections: Warmup connection count (0 = disabled, -1 = auto, >0 = explicit)
zmq_context: ManagedZMQContext when using ZMQ transport (required by default config).

Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _create_custom_client(
endpoint_urls=[f"{vllm_docker_server['url']}/v1/chat/completions"],
num_workers=num_workers,
max_connections=50,
warmup_connections=False,
warmup_connections=0,
)

# TODO(vir):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/endpoint_client/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def test_many_workers(self, mock_http_echo_server):
num_workers=num_workers,
max_connections=num_workers
* 10, # ensure each worker has connections
warmup_connections=False,
warmup_connections=0,
zmq_context=zmq_ctx,
)

Expand Down Expand Up @@ -330,7 +330,7 @@ async def test_streaming_error_propagation(self):
# Use invalid endpoint to trigger errors
client = create_futures_client(
"http://invalid-endpoint-12345:9999/v1/chat/completions",
warmup_connections=False,
warmup_connections=0,
zmq_context=zmq_ctx,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/endpoint_client/test_sglang_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def sglang_futures_client():
endpoint_urls=[SGLANG_ENDPOINT],
num_workers=4,
api_type="sglang",
warmup_connections=False,
warmup_connections=0,
)

client = FuturesHttpClient(http_config)
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/endpoint_client/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def worker_config(self, mock_http_echo_server):
endpoint_urls=[f"{mock_http_echo_server.url}/v1/chat/completions"],
num_workers=1,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)
return http_config

Expand Down Expand Up @@ -229,7 +229,7 @@ def worker_config(self, mock_http_echo_server):
endpoint_urls=[f"{mock_http_echo_server.url}/v1/chat/completions"],
num_workers=1,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)
return http_config

Expand All @@ -240,7 +240,7 @@ def error_config(self):
endpoint_urls=["http://localhost:59999/v1/chat/completions"],
num_workers=1,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)
return http_config

Expand Down Expand Up @@ -416,7 +416,7 @@ async def malformed_json_non_streaming_handler(request):
endpoint_urls=[f"http://localhost:{server.port}/malformed"],
num_workers=1,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)

worker = Worker(
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/endpoint_client/test_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def manager_config(self, mock_http_echo_server):
endpoint_urls=[f"{mock_http_echo_server.url}/v1/chat/completions"],
num_workers=2,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)
return http_config

Expand Down Expand Up @@ -270,7 +270,7 @@ def worker_death_config(self):
endpoint_urls=["http://localhost:59999/advanced"],
num_workers=2,
max_connections=10,
warmup_connections=False,
warmup_connections=0,
)
return http_config

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_end_to_end_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class DeepSeekR1SampleIssuer(HttpClientSampleIssuer):
def __init__(self, tmp_path: Path, url: str, zmq_context: ManagedZMQContext):
self.http_config = HTTPClientConfig(
endpoint_urls=[urljoin(url, "/v1/chat/completions")],
warmup_connections=False,
warmup_connections=0,
)
super().__init__(HTTPEndpointClient(self.http_config, zmq_context=zmq_context))

Expand Down
2 changes: 1 addition & 1 deletion tests/performance/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def http_client(perf_http_echo_server):
http_config = HTTPClientConfig(
endpoint_urls=[f"{perf_http_echo_server.url}/v1/chat/completions"],
num_workers=1,
warmup_connections=False,
warmup_connections=0,
)

client = HTTPEndpointClient(config=http_config)
Expand Down
Loading