44import multiprocessing
55import socket
66import time
7+ from contextlib import contextmanager
78
89import httpx
910import pytest
@@ -66,40 +67,48 @@ async def handle_sse(request: Request):
6667 uvicorn .run (starlette_app , host = "127.0.0.1" , port = port , log_level = "error" )
6768
6869
70+ @contextmanager
6971def start_server_process (port : int , security_settings : TransportSecuritySettings | None = None ):
7072 """Start server in a separate process."""
7173 context = multiprocessing .get_context ("spawn" )
7274 process = context .Process (target = run_server_with_settings , args = (port , security_settings ))
7375 process .start ()
74- # Give server time to start
75- time .sleep (1 )
76- return process
76+
77+ # Wait until the designated port can be connected
78+ max_attempts = 20
79+ for attempt in range (max_attempts ):
80+ try :
81+ with socket .create_connection (("127.0.0.1" , port )):
82+ break
83+ except ConnectionRefusedError :
84+ time .sleep (0.1 )
85+ else :
86+ raise RuntimeError (f"Server failed to start after { max_attempts } attempts" )
87+
88+ try :
89+ yield
90+ finally :
91+ process .terminate ()
92+ process .join ()
7793
7894
7995@pytest .mark .anyio
8096async def test_sse_security_default_settings (server_port : int ):
8197 """Test SSE with default security settings (protection disabled)."""
82- process = start_server_process (server_port )
83-
84- try :
98+ with start_server_process (server_port ):
8599 headers = {"Host" : "evil.com" , "Origin" : "http://evil.com" }
86100
87101 async with httpx .AsyncClient (timeout = 5.0 ) as client :
88102 async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
89103 assert response .status_code == 200
90- finally :
91- process .terminate ()
92- process .join ()
93104
94105
95106@pytest .mark .anyio
96107async def test_sse_security_invalid_host_header (server_port : int ):
97108 """Test SSE with invalid Host header."""
98109 # Enable security by providing settings with an empty allowed_hosts list
99110 security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ])
100- process = start_server_process (server_port , security_settings )
101-
102- try :
111+ with start_server_process (server_port , security_settings ):
103112 # Test with invalid host header
104113 headers = {"Host" : "evil.com" }
105114
@@ -108,10 +117,6 @@ async def test_sse_security_invalid_host_header(server_port: int):
108117 assert response .status_code == 421
109118 assert response .text == "Invalid Host header"
110119
111- finally :
112- process .terminate ()
113- process .join ()
114-
115120
116121@pytest .mark .anyio
117122async def test_sse_security_invalid_origin_header (server_port : int ):
@@ -120,9 +125,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
120125 security_settings = TransportSecuritySettings (
121126 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://localhost:*" ]
122127 )
123- process = start_server_process (server_port , security_settings )
124-
125- try :
128+ with start_server_process (server_port , security_settings ):
126129 # Test with invalid origin header
127130 headers = {"Origin" : "http://evil.com" }
128131
@@ -131,10 +134,6 @@ async def test_sse_security_invalid_origin_header(server_port: int):
131134 assert response .status_code == 400
132135 assert response .text == "Invalid Origin header"
133136
134- finally :
135- process .terminate ()
136- process .join ()
137-
138137
139138@pytest .mark .anyio
140139async def test_sse_security_post_invalid_content_type (server_port : int ):
@@ -143,9 +142,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
143142 security_settings = TransportSecuritySettings (
144143 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
145144 )
146- process = start_server_process (server_port , security_settings )
147-
148- try :
145+ with start_server_process (server_port , security_settings ):
149146 async with httpx .AsyncClient (timeout = 5.0 ) as client :
150147 # Test POST with invalid content type
151148 fake_session_id = "12345678123456781234567812345678"
@@ -164,18 +161,12 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
164161 assert response .status_code == 400
165162 assert response .text == "Invalid Content-Type header"
166163
167- finally :
168- process .terminate ()
169- process .join ()
170-
171164
172165@pytest .mark .anyio
173166async def test_sse_security_disabled (server_port : int ):
174167 """Test SSE with security disabled."""
175168 settings = TransportSecuritySettings (enable_dns_rebinding_protection = False )
176- process = start_server_process (server_port , settings )
177-
178- try :
169+ with start_server_process (server_port , settings ):
179170 # Test with invalid host header - should still work
180171 headers = {"Host" : "evil.com" }
181172
@@ -185,10 +176,6 @@ async def test_sse_security_disabled(server_port: int):
185176 # Should connect successfully even with invalid host
186177 assert response .status_code == 200
187178
188- finally :
189- process .terminate ()
190- process .join ()
191-
192179
193180@pytest .mark .anyio
194181async def test_sse_security_custom_allowed_hosts (server_port : int ):
@@ -198,9 +185,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
198185 allowed_hosts = ["localhost" , "127.0.0.1" , "custom.host" ],
199186 allowed_origins = ["http://localhost" , "http://127.0.0.1" , "http://custom.host" ],
200187 )
201- process = start_server_process (server_port , settings )
202-
203- try :
188+ with start_server_process (server_port , settings ):
204189 # Test with custom allowed host
205190 headers = {"Host" : "custom.host" }
206191
@@ -218,10 +203,6 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
218203 assert response .status_code == 421
219204 assert response .text == "Invalid Host header"
220205
221- finally :
222- process .terminate ()
223- process .join ()
224-
225206
226207@pytest .mark .anyio
227208async def test_sse_security_wildcard_ports (server_port : int ):
@@ -231,9 +212,7 @@ async def test_sse_security_wildcard_ports(server_port: int):
231212 allowed_hosts = ["localhost:*" , "127.0.0.1:*" ],
232213 allowed_origins = ["http://localhost:*" , "http://127.0.0.1:*" ],
233214 )
234- process = start_server_process (server_port , settings )
235-
236- try :
215+ with start_server_process (server_port , settings ):
237216 # Test with various port numbers
238217 for test_port in [8080 , 3000 , 9999 ]:
239218 headers = {"Host" : f"localhost:{ test_port } " }
@@ -252,10 +231,6 @@ async def test_sse_security_wildcard_ports(server_port: int):
252231 # Should connect successfully with any port
253232 assert response .status_code == 200
254233
255- finally :
256- process .terminate ()
257- process .join ()
258-
259234
260235@pytest .mark .anyio
261236async def test_sse_security_post_valid_content_type (server_port : int ):
@@ -264,9 +239,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
264239 security_settings = TransportSecuritySettings (
265240 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
266241 )
267- process = start_server_process (server_port , security_settings )
268-
269- try :
242+ with start_server_process (server_port , security_settings ):
270243 async with httpx .AsyncClient () as client :
271244 # Test with various valid content types
272245 valid_content_types = [
@@ -288,7 +261,3 @@ async def test_sse_security_post_valid_content_type(server_port: int):
288261 # We're testing that it passes the content-type check
289262 assert response .status_code == 404
290263 assert response .text == "Could not find session"
291-
292- finally :
293- process .terminate ()
294- process .join ()
0 commit comments