Skip to content

Commit 672374f

Browse files
committed
Fix SGLang adapter usage, run tests manually on local SGLang server
1 parent ad77838 commit 672374f

File tree

3 files changed

+29
-24
lines changed

src/inference_endpoint/endpoint_client/worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def add_chunk(self, delta: SGLangSSEDelta) -> StreamChunk | None:
187187
if not isinstance(delta, SGLangSSEDelta):
188188
return None
189189

190-
if delta.total_tokens == self.total_tokens:
190+
if delta.total_completion_tokens == self.total_tokens:
191191
return None
192192

193193
# In SGLang /generate, the .text field is the total accumulated text, not
@@ -196,8 +196,8 @@ def add_chunk(self, delta: SGLangSSEDelta) -> StreamChunk | None:
196196
if (start_idx := len(delta.text)) > len(self.text):
197197
content_diff = delta.text[start_idx:]
198198
self.text = delta.text
199-
self.token_ids.extend(delta.token_ids)
200-
self.total_tokens = delta.total_tokens
199+
self.token_ids.extend(delta.token_delta)
200+
self.total_tokens = delta.total_completion_tokens
201201
if delta.has_retractions:
202202
# For now, we won't be handling retractions if they occur, but we will
203203
# report it as part of the metadata if it does happen.
@@ -228,7 +228,7 @@ def get_final_output(self) -> QueryResult:
228228
"final_chunk": True,
229229
"retraction_occurred": self.retraction_occurred,
230230
"n_tokens": self.total_tokens,
231-
"output_tokens": self.token_ids,
231+
"token_ids": self.token_ids,
232232
},
233233
)
234234

src/inference_endpoint/sglang/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True):
4343
class SGLangGenerateRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
4444
input_ids: list[int]
4545
sampling_params: SamplingParams
46-
stream: bool = True
46+
stream: bool
4747

4848

4949
class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True):
@@ -65,6 +65,6 @@ class SGLangGenerateResponse(msgspec.Struct, kw_only=True, omit_defaults=True):
6565

6666
class SGLangSSEDelta(msgspec.Struct):
6767
text: str = ""
68-
token_delta: int = 0
68+
token_delta: list[int] = msgspec.field(default_factory=list)
6969
total_completion_tokens: int = 0
7070
has_retractions: bool = False

tests/integration/endpoint_client/test_sglang_adapter.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
"""Integration tests for SGLang adapter with real GPT-OSS server.
1717
1818
This test assumes a server running GPT-OSS is available at localhost:30000.
19-
To start a server, use:
20-
python3 -m sglang.launch_server --model-path <model> --host 0.0.0.0 --port 30000
2119
"""
2220

2321
import asyncio
@@ -98,6 +96,7 @@ class TestSGLangAdapterIntegration:
9896
"""Integration tests for SGLang adapter with real GPT-OSS server."""
9997

10098
@pytest.mark.asyncio
99+
@pytest.mark.run_explicitly
101100
@pytest.mark.integration
102101
async def test_sglang_non_streaming_request(self, sglang_futures_client):
103102
"""Test non-streaming request through SGLang adapter.
@@ -114,6 +113,10 @@ async def test_sglang_non_streaming_request(self, sglang_futures_client):
114113
"input_tokens": input_tokens,
115114
"stream": False,
116115
},
116+
headers={
117+
"Content-Type": "application/json",
118+
"Accept": "application/json",
119+
},
117120
)
118121

119122
future = sglang_futures_client.issue_query(query)
@@ -126,19 +129,14 @@ async def test_sglang_non_streaming_request(self, sglang_futures_client):
126129
assert len(result.response_output) > 0
127130

128131
# Verify metadata
129-
assert "metadata" in dir(result)
130132
assert result.metadata is not None
131133
assert "token_ids" in result.metadata
132134
assert "n_tokens" in result.metadata
133135
assert isinstance(result.metadata["token_ids"], list)
134136
assert isinstance(result.metadata["n_tokens"], int)
135137

136-
print(
137-
f"\nNon-streaming response: {result.response_output[:100]}..."
138-
) # Print first 100 chars
139-
print(f"Token count: {result.metadata['n_tokens']}")
140-
141138
@pytest.mark.asyncio
139+
@pytest.mark.run_explicitly
142140
@pytest.mark.integration
143141
async def test_sglang_streaming_request(self, sglang_futures_client):
144142
"""Test streaming request through SGLang adapter.
@@ -157,6 +155,10 @@ async def test_sglang_streaming_request(self, sglang_futures_client):
157155
"temperature": 0.8,
158156
"stream": True,
159157
},
158+
headers={
159+
"Content-Type": "application/json",
160+
"Accept": "text/event-stream",
161+
},
160162
)
161163

162164
future = sglang_futures_client.issue_query(query)
@@ -167,15 +169,18 @@ async def test_sglang_streaming_request(self, sglang_futures_client):
167169
assert "response_output" in dir(result)
168170
assert result.response_output is not None
169171

170-
# In streaming mode, response_output should contain accumulated output
171-
assert "output" in result.response_output
172-
output_chunks = result.response_output["output"]
173-
assert isinstance(output_chunks, list)
174-
assert len(output_chunks) > 0
172+
assert result.metadata is not None
173+
assert "token_ids" in result.metadata
174+
assert "n_tokens" in result.metadata
175+
assert isinstance(result.metadata["token_ids"], list)
176+
assert isinstance(result.metadata["n_tokens"], int)
175177

176-
# Reconstruct full text
177-
full_text = "".join(output_chunks)
178-
assert len(full_text) > 0
178+
# Check that something was generated, but no more than max_new_tokens
179+
assert 0 < result.metadata["n_tokens"] and result.metadata["n_tokens"] <= 100
179180

180-
print(f"\nStreaming response: {full_text[:100]}...") # Print first 100 chars
181-
print(f"Number of chunks: {len(output_chunks)}")
181+
# The token IDs in the result should be at most n_tokens because of retractions
182+
if result.metadata["retraction_occurred"]:
183+
assert len(result.metadata["token_ids"]) <= result.metadata["n_tokens"]
184+
else:
185+
# STOP token is not included in the response, but counts towards generated
186+
assert len(result.metadata["token_ids"]) + 1 == result.metadata["n_tokens"]

0 commit comments

Comments (0)