Skip to content

Commit 587042d

Browse files
authored
🐛 Emit http.disconnect ASGI receive() event on server shutting down for streaming responses (#2829)
1 parent c9a75fb commit 587042d

File tree

3 files changed

+86
-8
lines changed

3 files changed

+86
-8
lines changed

tests/protocols/test_http.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,76 @@ async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]):
775775
assert protocol.transport.is_closing()
776776

777777

778+
async def test_shutdown_during_streaming_sends_disconnect(http_protocol_cls: type[HTTPProtocol]):
779+
"""When the server shuts down during an SSE/streaming response,
780+
receive() should return http.disconnect so the ASGI app can stop."""
781+
got_disconnect_event = False
782+
783+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
784+
nonlocal got_disconnect_event
785+
786+
await send(
787+
{
788+
"type": "http.response.start",
789+
"status": 200,
790+
"headers": [(b"content-type", b"text/event-stream")],
791+
}
792+
)
793+
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
794+
795+
# This simulates an SSE app waiting for disconnect
796+
message = await receive()
797+
if message["type"] == "http.disconnect":
798+
got_disconnect_event = True
799+
800+
protocol = get_connected_protocol(app, http_protocol_cls)
801+
protocol.data_received(SIMPLE_GET_REQUEST)
802+
# Trigger server shutdown while the app is streaming
803+
protocol.shutdown() # type: ignore[attr-defined]
804+
await protocol.loop.run_one()
805+
assert got_disconnect_event
806+
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
807+
assert b"data: hello" in protocol.transport.buffer
808+
assert protocol.transport.is_closing()
809+
810+
811+
async def test_shutdown_during_streaming_allows_send_before_exit(http_protocol_cls: type[HTTPProtocol]):
812+
"""During server shutdown, the app should still be able to send() data
813+
(e.g., a farewell SSE event) before returning."""
814+
farewell_sent = False
815+
816+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
817+
nonlocal farewell_sent
818+
819+
await send(
820+
{
821+
"type": "http.response.start",
822+
"status": 200,
823+
"headers": [
824+
(b"content-type", b"text/event-stream"),
825+
(b"transfer-encoding", b"chunked"),
826+
],
827+
}
828+
)
829+
await send({"type": "http.response.body", "body": b"data: hello\n\n", "more_body": True})
830+
831+
# Wait for disconnect
832+
message = await receive()
833+
assert message["type"] == "http.disconnect"
834+
835+
# Send a farewell event — this should still work since the transport is open
836+
await send({"type": "http.response.body", "body": b"data: goodbye\n\n", "more_body": True})
837+
farewell_sent = True
838+
839+
protocol = get_connected_protocol(app, http_protocol_cls)
840+
protocol.data_received(SIMPLE_GET_REQUEST)
841+
protocol.shutdown() # type: ignore[attr-defined]
842+
await protocol.loop.run_one()
843+
assert farewell_sent
844+
assert b"data: hello" in protocol.transport.buffer
845+
assert b"data: goodbye" in protocol.transport.buffer
846+
847+
778848
async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]):
779849
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
780850
body = b""

uvicorn/protocols/http/h11_impl.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,8 @@ def shutdown(self) -> None:
344344
self.transport.close()
345345
else:
346346
self.cycle.keep_alive = False
347+
self.cycle.shutting_down = True
348+
self.cycle.message_event.set()
347349

348350
def pause_writing(self) -> None:
349351
"""
@@ -397,6 +399,7 @@ def __init__(
397399
self.disconnected = False
398400
self.keep_alive = True
399401
self.waiting_for_100_continue = conn.they_are_waiting_for_100_continue
402+
self.shutting_down = False
400403

401404
# Request state
402405
self.body = bytearray()
@@ -429,8 +432,9 @@ async def run_asgi(self, app: ASGI3Application) -> None:
429432
self.logger.error(msg)
430433
await self.send_500_response()
431434
elif not self.response_complete and not self.disconnected:
432-
msg = "ASGI callable returned without completing response."
433-
self.logger.error(msg)
435+
if not self.shutting_down:
436+
msg = "ASGI callable returned without completing response."
437+
self.logger.error(msg)
434438
self.transport.close()
435439
finally:
436440
self.on_response = lambda: None
@@ -528,12 +532,12 @@ async def receive(self) -> ASGIReceiveEvent:
528532
self.transport.write(output)
529533
self.waiting_for_100_continue = False
530534

531-
if not self.disconnected and not self.response_complete:
535+
if not self.disconnected and not self.response_complete and not self.shutting_down:
532536
self.flow.resume_reading()
533537
await self.message_event.wait()
534538
self.message_event.clear()
535539

536-
if self.disconnected or self.response_complete:
540+
if self.disconnected or self.response_complete or self.shutting_down:
537541
return {"type": "http.disconnect"}
538542

539543
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}

uvicorn/protocols/http/httptools_impl.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def shutdown(self) -> None:
349349
self.transport.close()
350350
else:
351351
self.cycle.keep_alive = False
352+
self.cycle.shutting_down = True
353+
self.cycle.message_event.set()
352354

353355
def pause_writing(self) -> None:
354356
"""
@@ -400,6 +402,7 @@ def __init__(
400402
self.disconnected = False
401403
self.keep_alive = keep_alive
402404
self.waiting_for_100_continue = expect_100_continue
405+
self.shutting_down = False
403406

404407
# Request state
405408
self.body = bytearray()
@@ -434,8 +437,9 @@ async def run_asgi(self, app: ASGI3Application) -> None:
434437
self.logger.error(msg)
435438
await self.send_500_response()
436439
elif not self.response_complete and not self.disconnected:
437-
msg = "ASGI callable returned without completing response."
438-
self.logger.error(msg)
440+
if not self.shutting_down:
441+
msg = "ASGI callable returned without completing response."
442+
self.logger.error(msg)
439443
self.transport.close()
440444
finally:
441445
self.on_response = lambda: None
@@ -560,12 +564,12 @@ async def receive(self) -> ASGIReceiveEvent:
560564
self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
561565
self.waiting_for_100_continue = False
562566

563-
if not self.disconnected and not self.response_complete:
567+
if not self.disconnected and not self.response_complete and not self.shutting_down:
564568
self.flow.resume_reading()
565569
await self.message_event.wait()
566570
self.message_event.clear()
567571

568-
if self.disconnected or self.response_complete:
572+
if self.disconnected or self.response_complete or self.shutting_down:
569573
return {"type": "http.disconnect"}
570574
message: HTTPRequestEvent = {"type": "http.request", "body": bytes(self.body), "more_body": self.more_body}
571575
self.body = bytearray()

0 commit comments

Comments
 (0)