Skip to content

Commit 2d2c62b

Browse files
committed
Improve WSGI compliance
The response body is closed if it has a close method as per PEP 3333. In addition the response headers are only sent when the first response body byte is available to send. Finally, an error is raised if start_response has not been called by the app.
1 parent cb443a4 commit 2d2c62b

File tree

2 files changed

+69
-33
lines changed

2 files changed

+69
-33
lines changed

src/hypercorn/app_wrappers.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,25 +84,40 @@ async def handle_http(
8484

8585
def run_app(self, environ: dict, send: Callable) -> None:
8686
headers: List[Tuple[bytes, bytes]]
87+
headers_sent = False
88+
response_started = False
8789
status_code: Optional[int] = None
8890

8991
def start_response(
9092
status: str,
9193
response_headers: List[Tuple[str, str]],
9294
exc_info: Optional[Exception] = None,
9395
) -> None:
94-
nonlocal headers, status_code
96+
nonlocal headers, response_started, status_code
9597

9698
raw, _ = status.split(" ", 1)
9799
status_code = int(raw)
98100
headers = [
99101
(name.lower().encode("ascii"), value.encode("ascii"))
100102
for name, value in response_headers
101103
]
102-
send({"type": "http.response.start", "status": status_code, "headers": headers})
104+
response_started = True
103105

104-
for output in self.app(environ, start_response):
105-
send({"type": "http.response.body", "body": output, "more_body": True})
106+
response_body = self.app(environ, start_response)
107+
108+
if not response_started:
109+
raise RuntimeError("WSGI app did not call start_response")
110+
111+
try:
112+
for output in response_body:
113+
if not headers_sent:
114+
send({"type": "http.response.start", "status": status_code, "headers": headers})
115+
headers_sent = True
116+
117+
send({"type": "http.response.body", "body": output, "more_body": True})
118+
finally:
119+
if hasattr(response_body, "close"):
120+
response_body.close()
106121

107122

108123
def _build_environ(scope: HTTPScope, body: bytes) -> dict:

tests/test_app_wrappers.py

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,28 @@ async def _send(message: ASGISendEvent) -> None:
6161
]
6262

6363

64+
async def _run_app(app: WSGIWrapper, scope: HTTPScope, body: bytes = b"") -> List[ASGISendEvent]:
65+
queue: asyncio.Queue = asyncio.Queue()
66+
await queue.put({"type": "http.request", "body": body})
67+
68+
messages = []
69+
70+
async def _send(message: ASGISendEvent) -> None:
71+
nonlocal messages
72+
messages.append(message)
73+
74+
event_loop = asyncio.get_running_loop()
75+
76+
def _call_soon(func: Callable, *args: Any) -> Any:
77+
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
78+
return future.result()
79+
80+
await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
81+
return messages
82+
83+
6484
@pytest.mark.asyncio
65-
async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None:
85+
async def test_wsgi_asyncio() -> None:
6686
app = WSGIWrapper(echo_body, 2**16)
6787
scope: HTTPScope = {
6888
"http_version": "1.1",
@@ -79,20 +99,7 @@ async def test_wsgi_asyncio(event_loop: asyncio.AbstractEventLoop) -> None:
7999
"server": None,
80100
"extensions": {},
81101
}
82-
queue: asyncio.Queue = asyncio.Queue()
83-
await queue.put({"type": "http.request"})
84-
85-
messages = []
86-
87-
async def _send(message: ASGISendEvent) -> None:
88-
nonlocal messages
89-
messages.append(message)
90-
91-
def _call_soon(func: Callable, *args: Any) -> Any:
92-
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
93-
return future.result()
94-
95-
await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
102+
messages = await _run_app(app, scope)
96103
assert messages == [
97104
{
98105
"headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")],
@@ -105,7 +112,7 @@ def _call_soon(func: Callable, *args: Any) -> Any:
105112

106113

107114
@pytest.mark.asyncio
108-
async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None:
115+
async def test_max_body_size() -> None:
109116
app = WSGIWrapper(echo_body, 4)
110117
scope: HTTPScope = {
111118
"http_version": "1.1",
@@ -122,25 +129,39 @@ async def test_max_body_size(event_loop: asyncio.AbstractEventLoop) -> None:
122129
"server": None,
123130
"extensions": {},
124131
}
125-
queue: asyncio.Queue = asyncio.Queue()
126-
await queue.put({"type": "http.request", "body": b"abcde"})
127-
messages = []
128-
129-
async def _send(message: ASGISendEvent) -> None:
130-
nonlocal messages
131-
messages.append(message)
132-
133-
def _call_soon(func: Callable, *args: Any) -> Any:
134-
future = asyncio.run_coroutine_threadsafe(func(*args), event_loop)
135-
return future.result()
136-
137-
await app(scope, queue.get, _send, partial(event_loop.run_in_executor, None), _call_soon)
132+
messages = await _run_app(app, scope, b"abcde")
138133
assert messages == [
139134
{"headers": [], "status": 400, "type": "http.response.start"},
140135
{"body": bytearray(b""), "type": "http.response.body", "more_body": False},
141136
]
142137

143138

139+
def no_start_response(environ: dict, start_response: Callable) -> List[bytes]:
140+
return [b"result"]
141+
142+
143+
@pytest.mark.asyncio
144+
async def test_no_start_response() -> None:
145+
app = WSGIWrapper(no_start_response, 2**16)
146+
scope: HTTPScope = {
147+
"http_version": "1.1",
148+
"asgi": {},
149+
"method": "GET",
150+
"headers": [],
151+
"path": "/",
152+
"root_path": "/",
153+
"query_string": b"a=b",
154+
"raw_path": b"/",
155+
"scheme": "http",
156+
"type": "http",
157+
"client": ("localhost", 80),
158+
"server": None,
159+
"extensions": {},
160+
}
161+
with pytest.raises(RuntimeError):
162+
await _run_app(app, scope)
163+
164+
144165
def test_build_environ_encoding() -> None:
145166
scope: HTTPScope = {
146167
"http_version": "1.0",

0 commit comments

Comments
 (0)