Skip to content

Commit cdeb882

Browse files
committed
Don't log an error when process_request returns a response.
Fix #1513.
1 parent 810bdeb commit cdeb882

File tree

9 files changed

+163
-80
lines changed

9 files changed

+163
-80
lines changed

src/websockets/asyncio/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ async def handshake(
9595
return_when=asyncio.FIRST_COMPLETED,
9696
)
9797

98-
# self.protocol.handshake_exc is always set when the connection is lost
99-
# before receiving a response, when the response cannot be parsed, or
100-
# when the response fails the handshake.
98+
# self.protocol.handshake_exc is set when the connection is lost before
99+
# receiving a response, when the response cannot be parsed, or when the
100+
# response fails the handshake.
101101

102102
if self.protocol.handshake_exc is not None:
103103
raise self.protocol.handshake_exc

src/websockets/asyncio/server.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,13 @@ async def handshake(
192192

193193
self.protocol.send_response(self.response)
194194

195-
# self.protocol.handshake_exc is always set when the connection is lost
196-
# before receiving a request, when the request cannot be parsed, when
197-
# the handshake encounters an error, or when process_request or
198-
# process_response sends an HTTP response that rejects the handshake.
195+
# self.protocol.handshake_exc is set when the connection is lost before
196+
# receiving a request, when the request cannot be parsed, or when the
197+
# handshake fails, including when process_request or process_response
198+
# raises an exception.
199+
200+
# It isn't set when process_request or process_response sends an HTTP
201+
# response that rejects the handshake.
199202

200203
if self.protocol.handshake_exc is not None:
201204
raise self.protocol.handshake_exc
@@ -360,7 +363,11 @@ async def conn_handler(self, connection: ServerConnection) -> None:
360363
connection.close_transport()
361364
return
362365

363-
assert connection.protocol.state is OPEN
366+
if connection.protocol.state is not OPEN:
367+
# process_request or process_response rejected the handshake.
368+
connection.close_transport()
369+
return
370+
364371
try:
365372
connection.start_keepalive()
366373
await self.handler(connection)

src/websockets/protocol.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,15 +518,34 @@ def close_expected(self) -> bool:
518518
Whether the TCP connection is expected to close soon.
519519
520520
"""
521-
# We expect a TCP close if and only if we sent a close frame:
521+
# During the opening handshake, when our state is CONNECTING, we expect
522+
# a TCP close if and only if the hansdake fails. When it does, we start
523+
# the TCP closing handshake by sending EOF with send_eof().
524+
525+
# Once the opening handshake completes successfully, we expect a TCP
526+
# close if and only if we sent a close frame, meaning that our state
527+
# progressed to CLOSING:
528+
522529
# * Normal closure: once we send a close frame, we expect a TCP close:
523530
# server waits for client to complete the TCP closing handshake;
524531
# client waits for server to initiate the TCP closing handshake.
532+
525533
# * Abnormal closure: we always send a close frame and the same logic
526534
# applies, except on EOFError where we don't send a close frame
527535
# because we already received the TCP close, so we don't expect it.
528-
# We already got a TCP Close if and only if the state is CLOSED.
529-
return self.state is CLOSING or self.handshake_exc is not None
536+
537+
# If our state is CLOSED, we already received a TCP close so we don't
538+
# expect it anymore.
539+
540+
# Micro-optimization: put the most common case first
541+
if self.state is OPEN:
542+
return False
543+
if self.state is CLOSING:
544+
return True
545+
if self.state is CLOSED:
546+
return False
547+
assert self.state is CONNECTING
548+
return self.eof_sent
530549

531550
# Private methods for receiving data.
532551

@@ -616,14 +635,14 @@ def discard(self) -> Generator[None]:
616635
# connection in the same circumstances where discard() replaces parse().
617636
# The client closes it when it receives EOF from the server or times
618637
# out. (The latter case cannot be handled in this Sans-I/O layer.)
619-
assert (self.state == CONNECTING or self.side is SERVER) == (self.eof_sent)
638+
assert (self.side is SERVER or self.state is CONNECTING) == (self.eof_sent)
620639
while not (yield from self.reader.at_eof()):
621640
self.reader.discard()
622641
if self.debug:
623642
self.logger.debug("< EOF")
624643
# A server closes the TCP connection immediately, while a client
625644
# waits for the server to close the TCP connection.
626-
if self.state != CONNECTING and self.side is CLIENT:
645+
if self.side is CLIENT and self.state is not CONNECTING:
627646
self.send_eof()
628647
self.state = CLOSED
629648
# If discard() completes normally, execution ends here.

src/websockets/server.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
InvalidHeader,
1515
InvalidHeaderValue,
1616
InvalidOrigin,
17-
InvalidStatus,
1817
InvalidUpgrade,
1918
NegotiationError,
2019
)
@@ -536,11 +535,6 @@ def send_response(self, response: Response) -> None:
536535
self.logger.info("connection open")
537536

538537
else:
539-
# handshake_exc may be already set if accept() encountered an error.
540-
# If the connection isn't open, set handshake_exc to guarantee that
541-
# handshake_exc is None if and only if opening handshake succeeded.
542-
if self.handshake_exc is None:
543-
self.handshake_exc = InvalidStatus(response)
544538
self.logger.info(
545539
"connection rejected (%d %s)",
546540
response.status_code,

src/websockets/sync/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ def handshake(
8787
if not self.response_rcvd.wait(timeout):
8888
raise TimeoutError("timed out during handshake")
8989

90-
# self.protocol.handshake_exc is always set when the connection is lost
91-
# before receiving a response, when the response cannot be parsed, or
92-
# when the response fails the handshake.
90+
# self.protocol.handshake_exc is set when the connection is lost before
91+
# receiving a response, when the response cannot be parsed, or when the
92+
# response fails the handshake.
9393

9494
if self.protocol.handshake_exc is not None:
9595
raise self.protocol.handshake_exc

src/websockets/sync/server.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,13 @@ def handshake(
170170

171171
self.protocol.send_response(self.response)
172172

173-
# self.protocol.handshake_exc is always set when the connection is lost
174-
# before receiving a request, when the request cannot be parsed, when
175-
# the handshake encounters an error, or when process_request or
176-
# process_response sends an HTTP response that rejects the handshake.
173+
# self.protocol.handshake_exc is set when the connection is lost before
174+
# receiving a request, when the request cannot be parsed, or when the
175+
# handshake fails, including when process_request or process_response
176+
# raises an exception.
177+
178+
# It isn't set when process_request or process_response sends an HTTP
179+
# response that rejects the handshake.
177180

178181
if self.protocol.handshake_exc is not None:
179182
raise self.protocol.handshake_exc

tests/asyncio/test_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ async def asyncTearDown(self):
5151
if sys.version_info[:2] < (3, 10): # pragma: no cover
5252

5353
@contextlib.contextmanager
54-
def assertNoLogs(self, logger="websockets", level=logging.ERROR):
54+
def assertNoLogs(self, logger=None, level=None):
5555
"""
5656
No message is logged on the given logger with at least the given level.
5757

tests/asyncio/test_server.py

Lines changed: 94 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,17 @@ def process_request(ws, request):
148148
async def handler(ws):
149149
self.fail("handler must not run")
150150

151-
async with serve(handler, *args[1:], process_request=process_request) as server:
152-
with self.assertRaises(InvalidStatus) as raised:
153-
async with connect(get_uri(server)):
154-
self.fail("did not raise")
155-
self.assertEqual(
156-
str(raised.exception),
157-
"server rejected WebSocket connection: HTTP 403",
158-
)
151+
with self.assertNoLogs("websockets", logging.ERROR):
152+
async with serve(
153+
handler, *args[1:], process_request=process_request
154+
) as server:
155+
with self.assertRaises(InvalidStatus) as raised:
156+
async with connect(get_uri(server)):
157+
self.fail("did not raise")
158+
self.assertEqual(
159+
str(raised.exception),
160+
"server rejected WebSocket connection: HTTP 403",
161+
)
159162

160163
async def test_async_process_request_returns_response(self):
161164
"""Server aborts handshake if async process_request returns a response."""
@@ -166,44 +169,65 @@ async def process_request(ws, request):
166169
async def handler(ws):
167170
self.fail("handler must not run")
168171

169-
async with serve(handler, *args[1:], process_request=process_request) as server:
170-
with self.assertRaises(InvalidStatus) as raised:
171-
async with connect(get_uri(server)):
172-
self.fail("did not raise")
173-
self.assertEqual(
174-
str(raised.exception),
175-
"server rejected WebSocket connection: HTTP 403",
176-
)
172+
with self.assertNoLogs("websockets", logging.ERROR):
173+
async with serve(
174+
handler, *args[1:], process_request=process_request
175+
) as server:
176+
with self.assertRaises(InvalidStatus) as raised:
177+
async with connect(get_uri(server)):
178+
self.fail("did not raise")
179+
self.assertEqual(
180+
str(raised.exception),
181+
"server rejected WebSocket connection: HTTP 403",
182+
)
177183

178184
async def test_process_request_raises_exception(self):
179185
"""Server returns an error if process_request raises an exception."""
180186

181187
def process_request(ws, request):
182-
raise RuntimeError
188+
raise RuntimeError("BOOM")
183189

184-
async with serve(*args, process_request=process_request) as server:
185-
with self.assertRaises(InvalidStatus) as raised:
186-
async with connect(get_uri(server)):
187-
self.fail("did not raise")
188-
self.assertEqual(
189-
str(raised.exception),
190-
"server rejected WebSocket connection: HTTP 500",
191-
)
190+
with self.assertLogs("websockets", logging.ERROR) as logs:
191+
async with serve(*args, process_request=process_request) as server:
192+
with self.assertRaises(InvalidStatus) as raised:
193+
async with connect(get_uri(server)):
194+
self.fail("did not raise")
195+
self.assertEqual(
196+
str(raised.exception),
197+
"server rejected WebSocket connection: HTTP 500",
198+
)
199+
self.assertEqual(
200+
[record.getMessage() for record in logs.records],
201+
["opening handshake failed"],
202+
)
203+
self.assertEqual(
204+
[str(record.exc_info[1]) for record in logs.records],
205+
["BOOM"],
206+
)
192207

193208
async def test_async_process_request_raises_exception(self):
194209
"""Server returns an error if async process_request raises an exception."""
195210

196211
async def process_request(ws, request):
197-
raise RuntimeError
212+
raise RuntimeError("BOOM")
198213

199-
async with serve(*args, process_request=process_request) as server:
200-
with self.assertRaises(InvalidStatus) as raised:
201-
async with connect(get_uri(server)):
202-
self.fail("did not raise")
203-
self.assertEqual(
204-
str(raised.exception),
205-
"server rejected WebSocket connection: HTTP 500",
206-
)
214+
with self.assertLogs("websockets", logging.ERROR) as logs:
215+
async with serve(*args, process_request=process_request) as server:
216+
with self.assertRaises(InvalidStatus) as raised:
217+
async with connect(get_uri(server)):
218+
self.fail("did not raise")
219+
self.assertEqual(
220+
str(raised.exception),
221+
"server rejected WebSocket connection: HTTP 500",
222+
)
223+
self.assertEqual(
224+
[record.getMessage() for record in logs.records],
225+
["opening handshake failed"],
226+
)
227+
self.assertEqual(
228+
[str(record.exc_info[1]) for record in logs.records],
229+
["BOOM"],
230+
)
207231

208232
async def test_process_response_returns_none(self):
209233
"""Server runs process_response but keeps the handshake response."""
@@ -277,31 +301,49 @@ async def test_process_response_raises_exception(self):
277301
"""Server returns an error if process_response raises an exception."""
278302

279303
def process_response(ws, request, response):
280-
raise RuntimeError
304+
raise RuntimeError("BOOM")
281305

282-
async with serve(*args, process_response=process_response) as server:
283-
with self.assertRaises(InvalidStatus) as raised:
284-
async with connect(get_uri(server)):
285-
self.fail("did not raise")
286-
self.assertEqual(
287-
str(raised.exception),
288-
"server rejected WebSocket connection: HTTP 500",
289-
)
306+
with self.assertLogs("websockets", logging.ERROR) as logs:
307+
async with serve(*args, process_response=process_response) as server:
308+
with self.assertRaises(InvalidStatus) as raised:
309+
async with connect(get_uri(server)):
310+
self.fail("did not raise")
311+
self.assertEqual(
312+
str(raised.exception),
313+
"server rejected WebSocket connection: HTTP 500",
314+
)
315+
self.assertEqual(
316+
[record.getMessage() for record in logs.records],
317+
["opening handshake failed"],
318+
)
319+
self.assertEqual(
320+
[str(record.exc_info[1]) for record in logs.records],
321+
["BOOM"],
322+
)
290323

291324
async def test_async_process_response_raises_exception(self):
292325
"""Server returns an error if async process_response raises an exception."""
293326

294327
async def process_response(ws, request, response):
295-
raise RuntimeError
328+
raise RuntimeError("BOOM")
296329

297-
async with serve(*args, process_response=process_response) as server:
298-
with self.assertRaises(InvalidStatus) as raised:
299-
async with connect(get_uri(server)):
300-
self.fail("did not raise")
301-
self.assertEqual(
302-
str(raised.exception),
303-
"server rejected WebSocket connection: HTTP 500",
304-
)
330+
with self.assertLogs("websockets", logging.ERROR) as logs:
331+
async with serve(*args, process_response=process_response) as server:
332+
with self.assertRaises(InvalidStatus) as raised:
333+
async with connect(get_uri(server)):
334+
self.fail("did not raise")
335+
self.assertEqual(
336+
str(raised.exception),
337+
"server rejected WebSocket connection: HTTP 500",
338+
)
339+
self.assertEqual(
340+
[record.getMessage() for record in logs.records],
341+
["opening handshake failed"],
342+
)
343+
self.assertEqual(
344+
[str(record.exc_info[1]) for record in logs.records],
345+
["BOOM"],
346+
)
305347

306348
async def test_override_server(self):
307349
"""Server can override Server header with server_header."""

tests/test_protocol.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
Frame,
2121
)
2222
from websockets.protocol import *
23-
from websockets.protocol import CLIENT, CLOSED, CLOSING, SERVER
23+
from websockets.protocol import CLIENT, CLOSED, CLOSING, CONNECTING, SERVER
2424

2525
from .extensions.utils import Rsv2Extension
2626
from .test_frames import FramesTestCase
@@ -1696,6 +1696,24 @@ def test_server_fails_connection(self):
16961696
server.fail(CloseCode.PROTOCOL_ERROR)
16971697
self.assertTrue(server.close_expected())
16981698

1699+
def test_client_is_connecting(self):
1700+
client = Protocol(CLIENT, state=CONNECTING)
1701+
self.assertFalse(client.close_expected())
1702+
1703+
def test_server_is_connecting(self):
1704+
server = Protocol(SERVER, state=CONNECTING)
1705+
self.assertFalse(server.close_expected())
1706+
1707+
def test_client_failed_connecting(self):
1708+
client = Protocol(CLIENT, state=CONNECTING)
1709+
client.send_eof()
1710+
self.assertTrue(client.close_expected())
1711+
1712+
def test_server_failed_connecting(self):
1713+
server = Protocol(SERVER, state=CONNECTING)
1714+
server.send_eof()
1715+
self.assertTrue(server.close_expected())
1716+
16991717

17001718
class ConnectionClosedTests(ProtocolTestCase):
17011719
"""

0 commit comments

Comments
 (0)