Skip to content

Commit ebe214f

Browse files
committed
fix: resolve streaming response corruption on HTTP keep-alive connections (#80)
Remove custom StreamingResponse that created duplicate ASGI receive() consumers, causing TransferEncodingError on second request. Replace BaseHTTPMiddleware with pure ASGI middleware to avoid streaming response pipe layer interference. Fix MockBaseEngine.count_chat_tokens signature.
1 parent cc12b0c commit ebe214f

File tree

3 files changed

+49
-55
lines changed

3 files changed

+49
-55
lines changed

omlx/server.py

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from fastapi import Depends, FastAPI, HTTPException, Request as FastAPIRequest
5656
from fastapi.middleware.cors import CORSMiddleware
5757
from fastapi.exceptions import RequestValidationError
58-
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse as _BaseStreamingResponse
58+
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
5959
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
6060

6161
from omlx._version import __version__
@@ -143,50 +143,6 @@
143143
logger = logging.getLogger(__name__)
144144

145145

146-
class StreamingResponse(_BaseStreamingResponse):
147-
"""StreamingResponse that aborts generation when client disconnects.
148-
149-
Monitors the ASGI receive channel for http.disconnect and closes
150-
the body iterator, propagating GeneratorExit through the engine's
151-
stream_generate which calls abort_request().
152-
"""
153-
154-
async def __call__(self, scope, receive, send):
155-
disconnected = asyncio.Event()
156-
157-
async def _monitor_disconnect():
158-
while True:
159-
message = await receive()
160-
if message.get("type") == "http.disconnect":
161-
disconnected.set()
162-
return
163-
164-
monitor_task = asyncio.create_task(_monitor_disconnect())
165-
166-
inner = self.body_iterator
167-
168-
async def _disconnect_aware():
169-
try:
170-
async for chunk in inner:
171-
if disconnected.is_set():
172-
logger.info("Client disconnected, stopping stream")
173-
return
174-
yield chunk
175-
finally:
176-
if hasattr(inner, "aclose"):
177-
await inner.aclose()
178-
179-
self.body_iterator = _disconnect_aware()
180-
try:
181-
await super().__call__(scope, receive, send)
182-
finally:
183-
monitor_task.cancel()
184-
try:
185-
await monitor_task
186-
except asyncio.CancelledError:
187-
pass
188-
189-
190146
# Security bearer for API key authentication
191147
security = HTTPBearer(auto_error=False)
192148

@@ -434,19 +390,57 @@ async def unhandled_exception_handler(request: FastAPIRequest, exc: Exception):
434390
)
435391

436392

437-
@app.middleware("http")
438-
async def debug_request_logging(request: FastAPIRequest, call_next):
439-
"""Log full request body for POST requests when debug logging is enabled."""
440-
if logger.isEnabledFor(5) and request.method == "POST":
441-
body = await request.body()
393+
class DebugRequestLoggingMiddleware:
394+
"""Pure ASGI middleware for trace-level request body logging.
395+
396+
Uses raw ASGI protocol instead of BaseHTTPMiddleware to avoid
397+
wrapping StreamingResponse in an intermediate pipe layer, which
398+
causes connection corruption on HTTP keep-alive connections.
399+
"""
400+
401+
def __init__(self, app):
402+
self.app = app
403+
404+
async def __call__(self, scope, receive, send):
405+
if (
406+
scope["type"] != "http"
407+
or not logger.isEnabledFor(5)
408+
or scope.get("method") != "POST"
409+
):
410+
await self.app(scope, receive, send)
411+
return
412+
413+
# Read and cache the request body for logging
414+
body_parts = []
415+
while True:
416+
message = await receive()
417+
body_parts.append(message)
418+
if not message.get("more_body", False):
419+
break
420+
421+
body = b"".join(part.get("body", b"") for part in body_parts)
442422
logger.log(
443423
5,
444424
"Incoming %s %s — body: %s",
445-
request.method, request.url.path,
425+
scope["method"],
426+
scope["path"],
446427
body.decode("utf-8", errors="replace"),
447428
)
448-
response = await call_next(request)
449-
return response
429+
430+
# Replay cached body for inner app, then forward real receive
431+
body_sent = False
432+
433+
async def cached_receive():
434+
nonlocal body_sent
435+
if not body_sent:
436+
body_sent = True
437+
return {"type": "http.request", "body": body, "more_body": False}
438+
return await receive()
439+
440+
await self.app(scope, cached_receive, send)
441+
442+
443+
app.add_middleware(DebugRequestLoggingMiddleware)
450444

451445

452446
# =============================================================================

tests/integration/test_e2e_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ async def stream_generate(self, prompt: str, **kwargs) -> AsyncIterator[MockGene
104104
finish_reason="stop",
105105
)
106106

107-
def count_chat_tokens(self, messages: List[Dict], tools=None) -> int:
107+
def count_chat_tokens(self, messages: List[Dict], tools=None, chat_template_kwargs=None) -> int:
108108
prompt = self._tokenizer.apply_chat_template(messages, tokenize=False)
109109
return len(self._tokenizer.encode(prompt))
110110

tests/integration/test_server_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ async def stream_generate(self, prompt: str, **kwargs):
177177
finish_reason="stop",
178178
)
179179

180-
def count_chat_tokens(self, messages: List[Dict], tools=None) -> int:
180+
def count_chat_tokens(self, messages: List[Dict], tools=None, chat_template_kwargs=None) -> int:
181181
prompt = self._tokenizer.apply_chat_template(messages, tokenize=False)
182182
return len(self._tokenizer.encode(prompt))
183183

0 commit comments

Comments (0)