Skip to content

Commit 4d4139b

Browse files
committed
optimize msgspec implementation
1 parent ef89455 commit 4d4139b

File tree

8 files changed

+1113
-14
lines changed

8 files changed

+1113
-14
lines changed

.cursor/rules/msgspec-patterns.mdc

Lines changed: 534 additions & 0 deletions
Large diffs are not rendered by default.

src/inference_endpoint/core/types.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,13 @@ class QueryStatus(Enum):
5252
_OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None
5353

5454

55-
class Query(msgspec.Struct, kw_only=True):
55+
class Query(
56+
msgspec.Struct,
57+
kw_only=True,
58+
array_like=True,
59+
omit_defaults=True,
60+
gc=False
61+
):
5662
"""Represents a single inference query to be sent to an endpoint.
5763
5864
A Query encapsulates all information needed to make an HTTP request to
@@ -72,6 +78,17 @@ class Query(msgspec.Struct, kw_only=True):
7278
... data={"prompt": "Hello", "model": "Qwen/Qwen3-8B", "max_tokens": 100},
7379
... headers={"Authorization": "Bearer token123"},
7480
... )
81+
82+
Note:
83+
gc=False: Safe because data/headers are simple key-value pairs without cycles.
84+
Do NOT store self-referential or cyclic structures in data/headers fields.
85+
86+
array_like=True: Encodes as array instead of object (e.g., ["id", {...}, {...}, 0.0]
87+
instead of {"id": ..., "data": ..., ...}). Provides ~6-50% size reduction and
88+
~6-29% ser/des speedup for ZMQ transport depending on payload size.
89+
90+
omit_defaults=True: Fields with default values are omitted during encoding,
91+
further reducing message size for queries with empty headers.
7592
"""
7693

7794
id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4()))
@@ -80,7 +97,15 @@ class Query(msgspec.Struct, kw_only=True):
8097
created_at: float = msgspec.field(default_factory=time.time)
8198

8299

83-
class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True):
100+
class QueryResult(
101+
msgspec.Struct,
102+
tag="query_result",
103+
kw_only=True,
104+
frozen=True,
105+
array_like=True,
106+
omit_defaults=True,
107+
gc=False,
108+
):
84109
"""Result of a completed inference query.
85110
86111
Represents the outcome of processing a Query, including the response text,
@@ -106,6 +131,15 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True)
106131
Note:
107132
The completed_at field is intentionally set internally to prevent
108133
benchmark result manipulation. Users must not override this timestamp.
134+
135+
gc=False: Safe because metadata contains only scalar key-value pairs.
136+
Do NOT store cyclic references in metadata or response_output fields.
137+
138+
omit_defaults=True: Fields with static defaults (i.e., those NOT using default_factory)
139+
are omitted if value equals default.
140+
141+
array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", false, {}]
142+
instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size.
109143
"""
110144

111145
id: str = ""
@@ -143,7 +177,14 @@ def __post_init__(self):
143177
self.response_output[k] = tuple(v)
144178

145179

146-
class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True):
180+
class StreamChunk(
181+
msgspec.Struct,
182+
tag="stream_chunk",
183+
kw_only=True,
184+
array_like=True,
185+
omit_defaults=True,
186+
gc=False,
187+
):
147188
"""A single chunk from a streaming inference response.
148189
149190
Streaming responses are sent incrementally as the model generates text.
@@ -163,6 +204,16 @@ class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True):
163204
Streaming "Hello World" might produce:
164205
>>> StreamChunk(id="q1", response_chunk="Hello", is_complete=False)
165206
>>> StreamChunk(id="q1", response_chunk=" World", is_complete=True)
207+
208+
Note:
209+
gc=False: Safe because metadata contains only scalar key-value pairs.
210+
Do NOT store cyclic references in metadata field.
211+
212+
omit_defaults=True: Fields with static defaults (i.e., those NOT using default_factory)
213+
are omitted if value equals default.
214+
215+
array_like=True: Encodes as array instead of object (e.g. ["id", "chunk", false, {}]
216+
instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size.
166217
"""
167218

168219
id: str = ""

src/inference_endpoint/openai/types.py

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,67 @@
2424
# ============================================================================
2525

2626

27-
class SSEDelta(msgspec.Struct):
27+
# NOTE(vir): msgspec usage
28+
# omit_defaults=True: Fields with static defaults are omitted if value equals default (i.e., those not using default_factory)
29+
# gc=False: Safe for request/response structs with scalar and nested struct fields only.
30+
31+
32+
class SSEDelta(
33+
msgspec.Struct,
34+
omit_defaults=True,
35+
gc=False
36+
):
2837
"""SSE delta object containing content."""
2938

3039
content: str = ""
3140
reasoning: str = ""
3241

3342

34-
class SSEChoice(msgspec.Struct):
43+
class SSEChoice(
44+
msgspec.Struct,
45+
omit_defaults=True,
46+
gc=False
47+
):
3548
"""SSE choice object containing delta."""
3649

3750
delta: SSEDelta = msgspec.field(default_factory=SSEDelta)
3851
finish_reason: str | None = None
3952

4053

41-
class SSEMessage(msgspec.Struct):
54+
class SSEMessage(
55+
msgspec.Struct,
56+
omit_defaults=True,
57+
gc=False
58+
):
4259
"""SSE message structure for OpenAI streaming responses."""
4360

4461
choices: list[SSEChoice] = msgspec.field(default_factory=list)
4562

4663

4764
# ============================================================================
48-
# OpenAI Chat Completion Types (msgspec-based)
65+
# OpenAI Chat Completion Types
4966
# ============================================================================
5067

5168

52-
class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
69+
class ChatMessage(
70+
msgspec.Struct,
71+
kw_only=True,
72+
omit_defaults=True,
73+
gc=False
74+
):
5375
"""Chat message in OpenAI format."""
5476

5577
role: str
5678
content: str
5779
name: str | None = None
5880

5981

60-
class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
82+
class ChatCompletionRequest(
83+
msgspec.Struct,
84+
kw_only=True,
85+
omit_defaults=True,
86+
gc=False
87+
):
6188
"""OpenAI chat completion request."""
6289

6390
model: str
@@ -76,32 +103,52 @@ class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
76103
user: str | None = None
77104

78105

79-
class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
106+
class ChatCompletionResponseMessage(
107+
msgspec.Struct,
108+
kw_only=True,
109+
omit_defaults=True,
110+
gc=False
111+
):
80112
"""Response message from OpenAI."""
81113

82114
role: str
83115
content: str | None
84116
refusal: str | None
85117

86118

87-
class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True):
119+
class ChatCompletionChoice(
120+
msgspec.Struct,
121+
kw_only=True,
122+
omit_defaults=True,
123+
gc=False
124+
):
88125
"""A single choice in the completion response."""
89126

90127
index: int
91128
message: ChatCompletionResponseMessage
92129
finish_reason: str | None
93130

94131

95-
class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True):
132+
class CompletionUsage(
133+
msgspec.Struct,
134+
kw_only=True,
135+
omit_defaults=True,
136+
gc=False
137+
):
96138
"""Token usage statistics."""
97139

98140
prompt_tokens: int
99141
completion_tokens: int
100142
total_tokens: int
101143

102144

103-
class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
104-
"""OpenAI chat completion response (msgspec version)."""
145+
class ChatCompletionResponse(
146+
msgspec.Struct,
147+
kw_only=True,
148+
omit_defaults=True,
149+
gc=False
150+
):
151+
"""OpenAI chat completion response."""
105152

106153
id: str
107154
object: str = "chat.completion"

tests/performance/openai/__init__.py

Whitespace-only changes.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Performance benchmarks for OpenAIAdapter (pydantic + orjson) using pytest-benchmark.
18+
19+
Measures ns/op for encode_query, decode_response, decode_sse_message
20+
with varying payload sizes (0, 100, 1k, 8k, 32k). Run with:
21+
22+
pytest tests/performance/openai/test_adapter.py --benchmark-only --benchmark-columns=mean,stddev,ops
23+
"""
24+
25+
import json
26+
27+
import pytest
28+
29+
from inference_endpoint.core.types import Query
30+
from inference_endpoint.openai.openai_adapter import OpenAIAdapter
31+
32+
# Benchmark payload sizes: label -> payload string of that many 'x' characters.
_SIZE_CHARS = {"empty": 0, "100": 100, "1k": 1_000, "8k": 8_000, "32k": 32_000}
TEXT_SIZES = {label: "x" * chars for label, chars in _SIZE_CHARS.items()}
39+
40+
41+
def make_query(text: str) -> Query:
    """Create a Query for benchmarks."""
    # Fixed id/model so benchmark runs are comparable across payload sizes.
    request_body = {
        "prompt": text,
        "model": "test-model",
        "max_completion_tokens": 100,
    }
    return Query(
        id="test-id",
        data=request_body,
        headers={"Authorization": "Bearer token"},
    )
48+
49+
50+
def make_response_bytes(text: str) -> bytes:
    """Create OpenAI-compatible response JSON bytes."""
    # Single choice carrying the benchmark payload as the assistant message.
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": text, "refusal": None},
        "finish_reason": "stop",
        "logprobs": None,
    }
    payload = {
        "id": "chatcmpl-test",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "test-model",
        "choices": [choice],
        "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
        "system_fingerprint": "fp_test",
    }
    return json.dumps(payload).encode()
70+
71+
72+
def make_sse_bytes(text: str) -> bytes:
    """Create SSE message JSON bytes."""
    delta = {"content": text, "reasoning": ""}
    message = {"choices": [{"delta": delta, "finish_reason": None}]}
    return json.dumps(message).encode()
81+
82+
83+
@pytest.mark.parametrize("size_name,text", TEXT_SIZES.items(), ids=TEXT_SIZES.keys())
def test_encode_query(benchmark, size_name, text):
    """Benchmark encode_query (Query -> HTTP bytes)."""
    # Group set first so reports bucket all payload sizes together.
    benchmark.group = "openai_adapter_encode_query"
    benchmark(OpenAIAdapter.encode_query, make_query(text))
89+
90+
91+
@pytest.mark.parametrize("size_name,text", TEXT_SIZES.items(), ids=TEXT_SIZES.keys())
def test_decode_response(benchmark, size_name, text):
    """Benchmark decode_response (HTTP bytes -> QueryResult)."""
    benchmark.group = "openai_adapter_decode_response"
    # Build the fixture outside the timed call so only decoding is measured.
    payload = make_response_bytes(text)
    benchmark(OpenAIAdapter.decode_response, payload, "test-id")
97+
98+
99+
@pytest.mark.parametrize("size_name,text", TEXT_SIZES.items(), ids=TEXT_SIZES.keys())
def test_decode_sse(benchmark, size_name, text):
    """Benchmark decode_sse_message (SSE bytes -> content)."""
    benchmark.group = "openai_adapter_decode_sse"
    # Build the fixture outside the timed call so only decoding is measured.
    benchmark(OpenAIAdapter.decode_sse_message, make_sse_bytes(text))

0 commit comments

Comments
 (0)