|
55 | 55 | from fastapi import Depends, FastAPI, HTTPException, Request as FastAPIRequest |
56 | 56 | from fastapi.middleware.cors import CORSMiddleware |
57 | 57 | from fastapi.exceptions import RequestValidationError |
58 | | -from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse as _BaseStreamingResponse |
| 58 | +from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse |
59 | 59 | from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
60 | 60 |
|
61 | 61 | from omlx._version import __version__ |
|
143 | 143 | logger = logging.getLogger(__name__) |
144 | 144 |
|
145 | 145 |
|
146 | | -class StreamingResponse(_BaseStreamingResponse): |
147 | | - """StreamingResponse that aborts generation when client disconnects. |
148 | | -
|
149 | | - Monitors the ASGI receive channel for http.disconnect and closes |
150 | | - the body iterator, propagating GeneratorExit through the engine's |
151 | | - stream_generate which calls abort_request(). |
152 | | - """ |
153 | | - |
154 | | - async def __call__(self, scope, receive, send): |
155 | | - disconnected = asyncio.Event() |
156 | | - |
157 | | - async def _monitor_disconnect(): |
158 | | - while True: |
159 | | - message = await receive() |
160 | | - if message.get("type") == "http.disconnect": |
161 | | - disconnected.set() |
162 | | - return |
163 | | - |
164 | | - monitor_task = asyncio.create_task(_monitor_disconnect()) |
165 | | - |
166 | | - inner = self.body_iterator |
167 | | - |
168 | | - async def _disconnect_aware(): |
169 | | - try: |
170 | | - async for chunk in inner: |
171 | | - if disconnected.is_set(): |
172 | | - logger.info("Client disconnected, stopping stream") |
173 | | - return |
174 | | - yield chunk |
175 | | - finally: |
176 | | - if hasattr(inner, "aclose"): |
177 | | - await inner.aclose() |
178 | | - |
179 | | - self.body_iterator = _disconnect_aware() |
180 | | - try: |
181 | | - await super().__call__(scope, receive, send) |
182 | | - finally: |
183 | | - monitor_task.cancel() |
184 | | - try: |
185 | | - await monitor_task |
186 | | - except asyncio.CancelledError: |
187 | | - pass |
188 | | - |
189 | | - |
190 | 146 | # Security bearer for API key authentication |
191 | 147 | security = HTTPBearer(auto_error=False) |
192 | 148 |
|
@@ -434,19 +390,57 @@ async def unhandled_exception_handler(request: FastAPIRequest, exc: Exception): |
434 | 390 | ) |
435 | 391 |
|
436 | 392 |
|
437 | | -@app.middleware("http") |
438 | | -async def debug_request_logging(request: FastAPIRequest, call_next): |
439 | | - """Log full request body for POST requests when debug logging is enabled.""" |
440 | | - if logger.isEnabledFor(5) and request.method == "POST": |
441 | | - body = await request.body() |
| 393 | +class DebugRequestLoggingMiddleware: |
| 394 | + """Pure ASGI middleware for trace-level request body logging. |
| 395 | +
|
| 396 | + Uses raw ASGI protocol instead of BaseHTTPMiddleware to avoid |
| 397 | + wrapping StreamingResponse in an intermediate pipe layer, which |
| 398 | + causes connection corruption on HTTP keep-alive connections. |
| 399 | + """ |
| 400 | + |
| 401 | + def __init__(self, app): |
| 402 | + self.app = app |
| 403 | + |
| 404 | + async def __call__(self, scope, receive, send): |
| 405 | + if ( |
| 406 | + scope["type"] != "http" |
| 407 | + or not logger.isEnabledFor(5) |
| 408 | + or scope.get("method") != "POST" |
| 409 | + ): |
| 410 | + await self.app(scope, receive, send) |
| 411 | + return |
| 412 | + |
| 413 | + # Read and cache the request body for logging |
| 414 | + body_parts = [] |
| 415 | + while True: |
| 416 | + message = await receive() |
| 417 | + body_parts.append(message) |
| 418 | + if not message.get("more_body", False): |
| 419 | + break |
| 420 | + |
| 421 | + body = b"".join(part.get("body", b"") for part in body_parts) |
442 | 422 | logger.log( |
443 | 423 | 5, |
444 | 424 | "Incoming %s %s — body: %s", |
445 | | - request.method, request.url.path, |
| 425 | + scope["method"], |
| 426 | + scope["path"], |
446 | 427 | body.decode("utf-8", errors="replace"), |
447 | 428 | ) |
448 | | - response = await call_next(request) |
449 | | - return response |
| 429 | + |
| 430 | + # Replay cached body for inner app, then forward real receive |
| 431 | + body_sent = False |
| 432 | + |
| 433 | + async def cached_receive(): |
| 434 | + nonlocal body_sent |
| 435 | + if not body_sent: |
| 436 | + body_sent = True |
| 437 | + return {"type": "http.request", "body": body, "more_body": False} |
| 438 | + return await receive() |
| 439 | + |
| 440 | + await self.app(scope, cached_receive, send) |
| 441 | + |
| 442 | + |
| 443 | +app.add_middleware(DebugRequestLoggingMiddleware) |
450 | 444 |
|
451 | 445 |
|
452 | 446 | # ============================================================================= |
|
0 commit comments