Skip to content

Commit 7759fbe

Browse files
committed
Fixed flaky server-side SSE tests
1 parent b1d242a commit 7759fbe

File tree

1 file changed

+27
-58
lines changed

1 file changed

+27
-58
lines changed

tests/server/test_sse_security.py

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import socket
66
import time
7+
from contextlib import contextmanager
78

89
import httpx
910
import pytest
@@ -66,40 +67,48 @@ async def handle_sse(request: Request):
6667
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
6768

6869

70+
@contextmanager
6971
def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
7072
"""Start server in a separate process."""
7173
context = multiprocessing.get_context("spawn")
7274
process = context.Process(target=run_server_with_settings, args=(port, security_settings))
7375
process.start()
74-
# Give server time to start
75-
time.sleep(1)
76-
return process
76+
77+
# Wait until the designated port can be connected
78+
max_attempts = 20
79+
for attempt in range(max_attempts):
80+
try:
81+
with socket.create_connection(("127.0.0.1", port)):
82+
break
83+
except ConnectionRefusedError:
84+
time.sleep(0.1)
85+
else:
86+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
87+
88+
try:
89+
yield
90+
finally:
91+
process.terminate()
92+
process.join()
7793

7894

7995
@pytest.mark.anyio
8096
async def test_sse_security_default_settings(server_port: int):
8197
"""Test SSE with default security settings (protection disabled)."""
82-
process = start_server_process(server_port)
83-
84-
try:
98+
with start_server_process(server_port):
8599
headers = {"Host": "evil.com", "Origin": "http://evil.com"}
86100

87101
async with httpx.AsyncClient(timeout=5.0) as client:
88102
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
89103
assert response.status_code == 200
90-
finally:
91-
process.terminate()
92-
process.join()
93104

94105

95106
@pytest.mark.anyio
96107
async def test_sse_security_invalid_host_header(server_port: int):
97108
"""Test SSE with invalid Host header."""
98109
# Enable security by providing settings with an empty allowed_hosts list
99110
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
100-
process = start_server_process(server_port, security_settings)
101-
102-
try:
111+
with start_server_process(server_port, security_settings):
103112
# Test with invalid host header
104113
headers = {"Host": "evil.com"}
105114

@@ -108,10 +117,6 @@ async def test_sse_security_invalid_host_header(server_port: int):
108117
assert response.status_code == 421
109118
assert response.text == "Invalid Host header"
110119

111-
finally:
112-
process.terminate()
113-
process.join()
114-
115120

116121
@pytest.mark.anyio
117122
async def test_sse_security_invalid_origin_header(server_port: int):
@@ -120,9 +125,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
120125
security_settings = TransportSecuritySettings(
121126
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
122127
)
123-
process = start_server_process(server_port, security_settings)
124-
125-
try:
128+
with start_server_process(server_port, security_settings):
126129
# Test with invalid origin header
127130
headers = {"Origin": "http://evil.com"}
128131

@@ -131,10 +134,6 @@ async def test_sse_security_invalid_origin_header(server_port: int):
131134
assert response.status_code == 400
132135
assert response.text == "Invalid Origin header"
133136

134-
finally:
135-
process.terminate()
136-
process.join()
137-
138137

139138
@pytest.mark.anyio
140139
async def test_sse_security_post_invalid_content_type(server_port: int):
@@ -143,9 +142,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
143142
security_settings = TransportSecuritySettings(
144143
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
145144
)
146-
process = start_server_process(server_port, security_settings)
147-
148-
try:
145+
with start_server_process(server_port, security_settings):
149146
async with httpx.AsyncClient(timeout=5.0) as client:
150147
# Test POST with invalid content type
151148
fake_session_id = "12345678123456781234567812345678"
@@ -164,18 +161,12 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
164161
assert response.status_code == 400
165162
assert response.text == "Invalid Content-Type header"
166163

167-
finally:
168-
process.terminate()
169-
process.join()
170-
171164

172165
@pytest.mark.anyio
173166
async def test_sse_security_disabled(server_port: int):
174167
"""Test SSE with security disabled."""
175168
settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
176-
process = start_server_process(server_port, settings)
177-
178-
try:
169+
with start_server_process(server_port, settings):
179170
# Test with invalid host header - should still work
180171
headers = {"Host": "evil.com"}
181172

@@ -185,10 +176,6 @@ async def test_sse_security_disabled(server_port: int):
185176
# Should connect successfully even with invalid host
186177
assert response.status_code == 200
187178

188-
finally:
189-
process.terminate()
190-
process.join()
191-
192179

193180
@pytest.mark.anyio
194181
async def test_sse_security_custom_allowed_hosts(server_port: int):
@@ -198,9 +185,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
198185
allowed_hosts=["localhost", "127.0.0.1", "custom.host"],
199186
allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"],
200187
)
201-
process = start_server_process(server_port, settings)
202-
203-
try:
188+
with start_server_process(server_port, settings):
204189
# Test with custom allowed host
205190
headers = {"Host": "custom.host"}
206191

@@ -218,10 +203,6 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
218203
assert response.status_code == 421
219204
assert response.text == "Invalid Host header"
220205

221-
finally:
222-
process.terminate()
223-
process.join()
224-
225206

226207
@pytest.mark.anyio
227208
async def test_sse_security_wildcard_ports(server_port: int):
@@ -231,9 +212,7 @@ async def test_sse_security_wildcard_ports(server_port: int):
231212
allowed_hosts=["localhost:*", "127.0.0.1:*"],
232213
allowed_origins=["http://localhost:*", "http://127.0.0.1:*"],
233214
)
234-
process = start_server_process(server_port, settings)
235-
236-
try:
215+
with start_server_process(server_port, settings):
237216
# Test with various port numbers
238217
for test_port in [8080, 3000, 9999]:
239218
headers = {"Host": f"localhost:{test_port}"}
@@ -252,10 +231,6 @@ async def test_sse_security_wildcard_ports(server_port: int):
252231
# Should connect successfully with any port
253232
assert response.status_code == 200
254233

255-
finally:
256-
process.terminate()
257-
process.join()
258-
259234

260235
@pytest.mark.anyio
261236
async def test_sse_security_post_valid_content_type(server_port: int):
@@ -264,9 +239,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
264239
security_settings = TransportSecuritySettings(
265240
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
266241
)
267-
process = start_server_process(server_port, security_settings)
268-
269-
try:
242+
with start_server_process(server_port, security_settings):
270243
async with httpx.AsyncClient() as client:
271244
# Test with various valid content types
272245
valid_content_types = [
@@ -288,7 +261,3 @@ async def test_sse_security_post_valid_content_type(server_port: int):
288261
# We're testing that it passes the content-type check
289262
assert response.status_code == 404
290263
assert response.text == "Could not find session"
291-
292-
finally:
293-
process.terminate()
294-
process.join()

0 commit comments

Comments
 (0)