Skip to content

Commit ad77838

Browse files
committed
Fix cyclical dependencies for SSE msgspec structs, add initial test for SGLang adapter
1 parent 657a601 commit ad77838

File tree

13 files changed

+412
-179
lines changed

13 files changed

+412
-179
lines changed

src/inference_endpoint/commands/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from inference_endpoint.commands.utils import get_default_report_path
3636
from inference_endpoint.config.runtime_settings import RuntimeSettings
3737
from inference_endpoint.config.schema import (
38+
APIType,
3839
BenchmarkConfig,
3940
ClientSettings,
4041
Dataset,
@@ -56,7 +57,6 @@
5657
from inference_endpoint.dataset_manager.factory import DataLoaderFactory
5758
from inference_endpoint.endpoint_client.configs import (
5859
AioHttpConfig,
59-
APIType,
6060
HTTPClientConfig,
6161
ZMQConfig,
6262
)

src/inference_endpoint/commands/probe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import time
2323
from urllib.parse import urljoin
2424

25+
from inference_endpoint.config.schema import APIType
2526
from inference_endpoint.core.types import Query, QueryResult
2627
from inference_endpoint.endpoint_client.configs import (
2728
AioHttpConfig,
28-
APIType,
2929
HTTPClientConfig,
3030
ZMQConfig,
3131
)

src/inference_endpoint/config/schema.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,23 @@
2727
from pydantic import BaseModel, Field
2828

2929
from .. import metrics
30-
from ..endpoint_client.configs import APIType
3130
from .ruleset_base import BenchmarkSuiteRuleset
3231

3332

33+
class APIType(str, Enum):
    """Inference API flavor, keyed by its lowercase wire name."""

    OPENAI = "openai"
    SGLANG = "sglang"

    def default_route(self) -> str:
        """Return the default HTTP route served by this API type.

        Raises:
            ValueError: If no route is known for this member.
        """
        if self is APIType.OPENAI:
            return "/v1/chat/completions"
        if self is APIType.SGLANG:
            return "/generate"
        raise ValueError(f"Invalid API type: {self}")
45+
46+
3447
class LoadPatternType(str, Enum):
3548
"""Load pattern types."""
3649

src/inference_endpoint/endpoint_client/configs.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,14 @@
1818
import os
1919
import socket
2020
from dataclasses import dataclass, field
21-
from enum import Enum
2221
from pathlib import Path
2322
from typing import Any
2423

2524
import aiohttp
2625
import zmq
2726

28-
from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter
29-
30-
31-
class APIType(Enum):
32-
OPENAI = "openai"
33-
SGLANG = "sglang"
34-
35-
def default_route(self) -> str:
36-
match self:
37-
case APIType.OPENAI:
38-
return "/v1/chat/completions"
39-
case APIType.SGLANG:
40-
return "/generate"
41-
case _:
42-
raise ValueError(f"Invalid API type: {self}")
27+
from ..config.schema import APIType
28+
from .adapter_protocol import HttpRequestAdapter
4329

4430

4531
@dataclass
@@ -80,6 +66,9 @@ class HTTPClientConfig:
8066

8167
def __post_init__(self):
8268
# set default adapter in __post_init__ to avoid circular dependency
69+
if isinstance(self.api_type, str):
70+
self.api_type = APIType(self.api_type)
71+
8372
if self.adapter is None:
8473
if self.api_type == APIType.OPENAI:
8574
from inference_endpoint.openai.openai_msgspec_adapter import (

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,24 @@
3131
import zmq
3232
import zmq.asyncio
3333

34+
from inference_endpoint.config.schema import APIType
3435
from inference_endpoint.core.types import (
3536
Query,
3637
QueryResult,
3738
StreamChunk,
3839
)
3940
from inference_endpoint.endpoint_client.configs import (
4041
AioHttpConfig,
41-
APIType,
4242
HTTPClientConfig,
4343
ZMQConfig,
4444
)
4545
from inference_endpoint.endpoint_client.zmq_utils import ZMQPullSocket, ZMQPushSocket
4646
from inference_endpoint.load_generator.events import SampleEvent
4747
from inference_endpoint.metrics.recorder import EventRecorder
4848
from inference_endpoint.metrics.reporter import MetricsReporter
49-
from inference_endpoint.openai.openai_adapter import SSEDelta as OpenAISSEDelta
49+
from inference_endpoint.openai.types import SSEDelta as OpenAISSEDelta
5050
from inference_endpoint.profiling import profile
51-
from inference_endpoint.sglang.adapter import SGLangSSEDelta
51+
from inference_endpoint.sglang.types import SGLangSSEDelta
5252
from inference_endpoint.utils.logging import setup_logging
5353

5454
logger = logging.getLogger(__name__)

src/inference_endpoint/openai/openai_adapter.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,7 @@
3434
Role6,
3535
ServiceTier,
3636
)
37-
38-
39-
# msgspec structs for typed SSE message parsing (OpenAI streaming format)
40-
class SSEDelta(msgspec.Struct):
41-
"""SSE delta object containing content."""
42-
43-
content: str = ""
44-
reasoning: str = ""
45-
46-
47-
class SSEChoice(msgspec.Struct):
48-
"""SSE choice object containing delta."""
49-
50-
delta: SSEDelta = msgspec.field(default_factory=SSEDelta)
51-
finish_reason: str | None = None
52-
53-
54-
class SSEMessage(msgspec.Struct):
55-
"""SSE message structure for OpenAI streaming responses."""
56-
57-
choices: list[SSEChoice] = msgspec.field(default_factory=list)
37+
from .types import SSEMessage
5838

5939

6040
class OpenAIAdapter(HttpRequestAdapter):

src/inference_endpoint/openai/openai_msgspec_adapter.py

Lines changed: 8 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -25,75 +25,14 @@
2525
# Import base class and shared SSE types
2626
from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter
2727

28-
from .openai_adapter import SSEMessage
29-
30-
# ============================================================================
31-
# msgspec Structs for OpenAI API Types
32-
# ============================================================================
33-
34-
35-
class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
36-
"""Chat message in OpenAI format."""
37-
38-
role: str
39-
content: str
40-
name: str | None = None
41-
42-
43-
class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
44-
"""OpenAI chat completion request."""
45-
46-
model: str
47-
messages: list[ChatMessage]
48-
temperature: float | None = None
49-
max_completion_tokens: int | None = None
50-
stream: bool | None = None
51-
top_p: float | None = None
52-
top_k: int | None = None
53-
repetition_penalty: float | None = None
54-
n: int | None = None
55-
stop: str | list[str] | None = None
56-
presence_penalty: float | None = None
57-
frequency_penalty: float | None = None
58-
logit_bias: dict[str, float] | None = None
59-
user: str | None = None
60-
61-
62-
class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
63-
"""Response message from OpenAI."""
64-
65-
role: str
66-
content: str | None
67-
refusal: str | None
68-
69-
70-
class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True):
71-
"""A single choice in the completion response."""
72-
73-
index: int
74-
message: ChatCompletionResponseMessage
75-
finish_reason: str | None
76-
77-
78-
class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True):
79-
"""Token usage statistics."""
80-
81-
prompt_tokens: int
82-
completion_tokens: int
83-
total_tokens: int
84-
85-
86-
class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
87-
"""OpenAI chat completion response (msgspec version)."""
88-
89-
id: str
90-
object: str = "chat.completion"
91-
created: int
92-
model: str
93-
choices: list[ChatCompletionChoice]
94-
usage: CompletionUsage | None
95-
system_fingerprint: str | None
96-
28+
from .types import (
29+
ChatCompletionChoice,
30+
ChatCompletionRequest,
31+
ChatCompletionResponse,
32+
ChatCompletionResponseMessage,
33+
ChatMessage,
34+
SSEMessage,
35+
)
9736

9837
# ============================================================================
9938
# msgspec-based OpenAI Adapter
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
msgspec types for OpenAI API serialization/deserialization.
18+
"""
19+
20+
import msgspec
21+
22+
# ============================================================================
23+
# SSE (Server-Sent Events) Types for OpenAI streaming format
24+
# ============================================================================
25+
26+
27+
class SSEDelta(msgspec.Struct):
    """SSE delta object containing content.

    Decoded from the ``delta`` field of an OpenAI-style streaming chunk.
    """

    # Incremental completion text for this chunk; defaults to "" so chunks
    # that omit the field (e.g. metadata-only chunks) decode cleanly.
    content: str = ""
    # Incremental reasoning text — presumably a separate "thinking" stream
    # emitted by some servers; TODO(review) confirm against target backends.
    reasoning: str = ""
32+
33+
34+
class SSEChoice(msgspec.Struct):
    """SSE choice object containing delta.

    One entry of the ``choices`` array in a streaming chunk.
    """

    # Chunk payload; default_factory yields an empty delta when the server
    # omits the field entirely.
    delta: SSEDelta = msgspec.field(default_factory=SSEDelta)
    # NOTE(review): assumed non-None only on the stream's final chunk
    # (e.g. "stop") — confirm against server output.
    finish_reason: str | None = None
39+
40+
41+
class SSEMessage(msgspec.Struct):
    """SSE message structure for OpenAI streaming responses."""

    # Streaming choices; defaults to an empty list so keep-alive or
    # metadata-only events decode without error.
    choices: list[SSEChoice] = msgspec.field(default_factory=list)
45+
46+
47+
# ============================================================================
48+
# OpenAI Chat Completion Types (msgspec-based)
49+
# ============================================================================
50+
51+
52+
class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
    """Chat message in OpenAI format.

    ``kw_only`` forces keyword construction; ``omit_defaults`` keeps unset
    optional fields out of the encoded request body.
    """

    # Author role, e.g. "system", "user", or "assistant".
    role: str
    # Message text.
    content: str
    # Optional participant name; omitted from the wire format when None.
    name: str | None = None
58+
59+
60+
class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
    """OpenAI chat completion request.

    ``omit_defaults`` ensures only explicitly-set sampling parameters are
    serialized, so the server applies its own defaults for the rest.
    """

    # Target model identifier.
    model: str
    # Conversation history, oldest first.
    messages: list[ChatMessage]
    temperature: float | None = None
    max_completion_tokens: int | None = None
    # When True the server responds with an SSE stream (see SSEMessage).
    stream: bool | None = None
    top_p: float | None = None
    # NOTE(review): top_k and repetition_penalty are not part of the upstream
    # OpenAI schema — presumably extensions accepted by the inference servers
    # this tool targets; confirm before relying on them.
    top_k: int | None = None
    repetition_penalty: float | None = None
    # Number of completions to generate per request.
    n: int | None = None
    # Stop sequence(s) that terminate generation.
    stop: str | list[str] | None = None
    presence_penalty: float | None = None
    frequency_penalty: float | None = None
    # Per-token-id logit adjustments, keyed by token id as a string.
    logit_bias: dict[str, float] | None = None
    # End-user identifier forwarded to the server.
    user: str | None = None
77+
78+
79+
class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True):
    """Response message from OpenAI.

    Required (no defaults): decoding fails loudly if the server omits any
    of these fields.
    """

    # Author role of the generated message.
    role: str
    # Generated text; None when the model produced no content.
    content: str | None
    # Refusal message, when the model declined to answer; None otherwise.
    refusal: str | None
85+
86+
87+
class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True):
    """A single choice in the completion response."""

    # Position of this choice within the response's choices array.
    index: int
    # The generated message for this choice.
    message: ChatCompletionResponseMessage
    # Why generation stopped (e.g. "stop", "length"); may be None.
    finish_reason: str | None
93+
94+
95+
class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True):
    """Token usage statistics."""

    # Tokens consumed by the prompt/messages.
    prompt_tokens: int
    # Tokens generated in the completion.
    completion_tokens: int
    # prompt_tokens + completion_tokens, as reported by the server.
    total_tokens: int
101+
102+
103+
class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
    """OpenAI chat completion response (msgspec version)."""

    # Server-assigned response identifier.
    id: str
    # Object type discriminator; only the default is expected here.
    object: str = "chat.completion"
    # Creation time as a Unix timestamp (seconds).
    created: int
    # Model that produced the response.
    model: str
    # One entry per requested completion (request parameter ``n``).
    choices: list[ChatCompletionChoice]
    # Token accounting; None if the server does not report usage.
    usage: CompletionUsage | None
    # Opaque backend fingerprint; None when not provided.
    system_fingerprint: str | None

0 commit comments

Comments
 (0)