diff --git a/docs/conf.py b/docs/conf.py index 798d595db..3dcc04169 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,6 +50,8 @@ ("py:meth", "protocol.WebSocketCommonProtocol.connection_lost"), ("py:meth", "protocol.WebSocketCommonProtocol.read_message"), ("py:meth", "protocol.WebSocketCommonProtocol.write_frame"), + # Caused by https://github.com/sphinx-doc/sphinx/issues/13838 + ("py:class", "ssl_module.SSLContext"), ] # Add any Sphinx extension module names here, as strings. They can be @@ -85,6 +87,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "sesame": ("https://django-sesame.readthedocs.io/en/stable/", None), + "trio": ("https://trio.readthedocs.io/en/stable/", None), "werkzeug": ("https://werkzeug.palletsprojects.com/en/stable/", None), } diff --git a/docs/index.rst b/docs/index.rst index 738258688..0774b91b8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -70,6 +70,10 @@ Here's an echo server and corresponding client. .. literalinclude:: ../example/sync/echo.py +.. tab:: trio + + .. literalinclude:: ../example/trio/echo.py + .. tab:: asyncio :new-set: @@ -79,6 +83,11 @@ Here's an echo server and corresponding client. .. literalinclude:: ../example/sync/hello.py +.. tab:: trio + + .. literalinclude:: ../example/trio/hello.py + + Don't worry about the opening and closing handshakes, pings and pongs, or any other behavior described in the WebSocket specification. websockets takes care of this under the hood so you can focus on your application! diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f6d7abb76..c43c20f71 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -43,6 +43,14 @@ Backwards-incompatible changes New features ............ +.. admonition:: websockets 16.0 introduces a :mod:`trio` implementation. + :class: important + + It is an alternative to the :mod:`asyncio` implementation. + + See :func:`websockets.trio.client.connect` and + :func:`websockets.trio.server.serve` for details. + * Validated compatibility with Python 3.14. Improvements diff --git a/docs/reference/asyncio/server.rst b/docs/reference/asyncio/server.rst index a245929ef..e8d80902b 100644 --- a/docs/reference/asyncio/server.rst +++ b/docs/reference/asyncio/server.rst @@ -46,7 +46,7 @@ Running a server .. automethod:: serve_forever - .. autoattribute:: sockets + .. autoproperty:: sockets Using a connection ------------------ diff --git a/docs/reference/features.rst b/docs/reference/features.rst index e5f6e0de0..dbeb6b43d 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -16,6 +16,7 @@ Feature support matrices summarize which implementations support which features. .. |aio| replace:: :mod:`asyncio` (new) .. |sync| replace:: :mod:`threading` +.. |trio| replace:: :mod:`trio` .. |sans| replace:: `Sans-I/O`_ .. |leg| replace:: :mod:`asyncio` (legacy) .. _Sans-I/O: https://sans-io.readthedocs.io/ @@ -26,68 +27,68 @@ Both sides .. 
table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce opening timeout | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Broadcast a message | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Receive a message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Iterate over received messages | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message frame | ✅ | ✅ | — | ❌ | - | by frame | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Receive a fragmented message after | ✅ | ✅ | — | ✅ | - | reassembly | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Force sending a message as Text or | ✅ | ✅ | — | ❌ | - | Binary | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Force receiving a message as | ✅ | ✅ | — | ❌ | - | :class:`bytes` or :class:`str` | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Send a ping | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Send a pong | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Measure latency | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce closing timeout | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | - | from both sides | | | | | - +------------------------------------+--------+--------+--------+--------+ - | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Enforce security limits | ✅ | ✅ | ✅ | ✅ | - 
+------------------------------------+--------+--------+--------+--------+ - | Log events | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Perform the opening handshake | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce opening timeout | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Broadcast a message | ✅ | ❌ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Iterate over received messages | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a fragmented message | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a fragmented message frame | ✅ | ✅ | ✅ | — | ❌ | + | by frame | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Receive a fragmented message after | ✅ | ✅ | ✅ | — | ✅ | + | reassembly | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force sending a message as Text or | ✅ | ✅ | ✅ | — | ❌ | + | Binary | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force receiving a message as | ✅ | ✅ | ✅ | — | ❌ | + | :class:`bytes` or :class:`str` | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a ping | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Respond to pings automatically | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Send a pong | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Keepalive | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Heartbeat | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Measure latency | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce closing timeout | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Report close codes and reasons | ✅ | ✅ | ✅ | ✅ | ❌ | + | from both sides | | | | | | + +------------------------------------+--------+--------+--------+--------+--------+ + | Compress messages (:rfc:`7692`) | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Tune memory usage for compression | ✅ | ✅ | ✅ | ✅ | ✅ | + 
+------------------------------------+--------+--------+--------+--------+--------+ + | Negotiate extensions | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Implement custom extensions | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Negotiate a subprotocol | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Enforce security limits | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Log events | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ Server ------ @@ -95,39 +96,39 @@ Server .. table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Listen on a TCP socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Listen on a Unix socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Listen using a preexisting socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close server on context exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close connection on handler exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Shut down server gracefully | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Alter opening handshake request | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Alter opening handshake response | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ❌ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Dispatch connections to handlers | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Listen on a TCP socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Listen on a Unix socket | ✅ | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Listen using a preexisting socket | ✅ | ✅ | ✅ | — | ✅ | + 
+------------------------------------+--------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close server on context exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close connection on handler exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Shut down server gracefully | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Check ``Origin`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Customize subprotocol selection | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``Server`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Alter opening handshake request | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Alter opening handshake response | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Force an HTTP response | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ❌ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Route connections to handlers | ✅ | ✅ | ❌ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ Client ------ @@ -135,39 +136,39 @@ Client .. 
table:: :class: support-matrix-table - +------------------------------------+--------+--------+--------+--------+ - | | |aio| | |sync| | |sans| | |leg| | - +====================================+========+========+========+========+ - | Connect to a TCP socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect to a Unix socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect using a preexisting socket | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Encrypt connection with TLS | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Close connection on context exit | ✅ | ✅ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Reconnect automatically | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Modify opening handshake response | ✅ | ✅ | ✅ | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Follow HTTP redirects | ✅ | ❌ | — | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ - | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | - +------------------------------------+--------+--------+--------+--------+ + +------------------------------------+--------+--------+--------+--------+--------+ + | | |aio| | |sync| | |trio| | |sans| | |leg| | + +====================================+========+========+========+========+========+ + | Connect to a TCP socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect to a Unix socket | ✅ | ✅ | ❌ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect using a preexisting socket | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Encrypt connection with TLS | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Close connection on context exit | ✅ | ✅ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Reconnect automatically | ✅ | ❌ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``Origin`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Configure ``User-Agent`` header | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Modify opening handshake request | ✅ | ✅ | ✅ | ✅ | ✅ | + 
+------------------------------------+--------+--------+--------+--------+--------+ + | Modify opening handshake response | ✅ | ✅ | ✅ | ✅ | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Follow HTTP redirects | ✅ | ❌ | ✅ | — | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | ✅ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect via HTTP proxy | ✅ | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ + | Connect via SOCKS5 proxy | ✅ | ✅ | ✅ | — | ❌ | + +------------------------------------+--------+--------+--------+--------+--------+ Known limitations ----------------- diff --git a/docs/reference/index.rst b/docs/reference/index.rst index cc9542c24..64a393d53 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,6 +37,17 @@ This alternative implementation can be a good choice for clients. sync/server sync/client +:mod:`trio` +------------ + +This is another option for servers that handle many clients concurrently. + +.. toctree:: + :titlesonly: + + trio/server + trio/client + `Sans-I/O`_ ----------- diff --git a/docs/reference/trio/client.rst b/docs/reference/trio/client.rst new file mode 100644 index 000000000..cf5643c55 --- /dev/null +++ b/docs/reference/trio/client.rst @@ -0,0 +1,63 @@ +Client (:mod:`trio`) +======================= + +.. automodule:: websockets.trio.client + +Opening a connection +-------------------- + +.. autofunction:: connect + :async: + +.. autofunction:: process_exception + +Using a connection +------------------ + +.. autoclass:: ClientConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/trio/common.rst b/docs/reference/trio/common.rst new file mode 100644 index 000000000..a1c68e0eb --- /dev/null +++ b/docs/reference/trio/common.rst @@ -0,0 +1,54 @@ +:orphan: + +Both sides (:mod:`trio`) +=========================== + +.. automodule:: websockets.trio.connection + +.. autoclass:: Connection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. 
autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason diff --git a/docs/reference/trio/server.rst b/docs/reference/trio/server.rst new file mode 100644 index 000000000..e3d92ed45 --- /dev/null +++ b/docs/reference/trio/server.rst @@ -0,0 +1,84 @@ +Server (:mod:`trio`) +======================= + +.. automodule:: websockets.trio.server + +Creating a server +----------------- + +.. autofunction:: serve + :async: + +.. currentmodule:: websockets.trio.server + +Running a server +---------------- + +.. autoclass:: Server + + .. autoattribute:: connections + + .. automethod:: aclose + + .. autoattribute:: listeners + +Using a connection +------------------ + +.. autoclass:: ServerConnection + + .. automethod:: __aiter__ + + .. automethod:: recv + + .. automethod:: recv_streaming + + .. automethod:: send + + .. automethod:: aclose + + .. automethod:: wait_closed + + .. automethod:: ping + + .. automethod:: pong + + .. automethod:: respond + + WebSocket connection objects also provide these attributes: + + .. autoattribute:: id + + .. autoattribute:: logger + + .. autoproperty:: local_address + + .. autoproperty:: remote_address + + .. autoattribute:: latency + + .. autoproperty:: state + + The following attributes are available after the opening handshake, + once the WebSocket connection is open: + + .. autoattribute:: request + + .. autoattribute:: response + + .. autoproperty:: subprotocol + + The following attributes are available after the closing handshake, + once the WebSocket connection is closed: + + .. autoproperty:: close_code + + .. autoproperty:: close_reason + +HTTP Basic Authentication +------------------------- + +websockets supports HTTP Basic Authentication according to +:rfc:`7235` and :rfc:`7617`. + +.. 
autofunction:: basic_auth diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index e63c2f8f5..fd8300183 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames:: Alternatively, you can measure the latency at any time by calling :attr:`~asyncio.connection.Connection.ping` and awaiting its result:: - pong_waiter = await websocket.ping() - latency = await pong_waiter + pong_received = await websocket.ping() + latency = await pong_received Latency between a client and a server may increase for two reasons: diff --git a/example/asyncio/client.py b/example/asyncio/client.py old mode 100644 new mode 100755 index e3562642d..4d40f97c4 --- a/example/asyncio/client.py +++ b/example/asyncio/client.py @@ -3,7 +3,6 @@ """Client example using the asyncio API.""" import asyncio - from websockets.asyncio.client import connect diff --git a/example/asyncio/server.py b/example/asyncio/server.py old mode 100644 new mode 100755 diff --git a/example/sync/client.py b/example/sync/client.py old mode 100644 new mode 100755 diff --git a/example/sync/server.py b/example/sync/server.py old mode 100644 new mode 100755 diff --git a/example/trio/client.py b/example/trio/client.py new file mode 100755 index 000000000..8bb5d9759 --- /dev/null +++ b/example/trio/client.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +"""Client example using the trio API.""" + +import trio +from websockets.trio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + name = input("What's your name? ") + + await websocket.send(name) + print(f">>> {name}") + + greeting = await websocket.recv() + print(f"<<< {greeting}") + + +if __name__ == "__main__": + trio.run(hello) diff --git a/example/trio/echo.py b/example/trio/echo.py new file mode 100755 index 000000000..e995b767e --- /dev/null +++ b/example/trio/echo.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +"""Echo server using the trio API.""" + +import trio +from websockets.trio.server import serve + + +async def echo(websocket): + async for message in websocket: + await websocket.send(message) + + +if __name__ == "__main__": + trio.run(serve, echo, 8765) diff --git a/example/trio/hello.py b/example/trio/hello.py new file mode 100755 index 000000000..1accba49e --- /dev/null +++ b/example/trio/hello.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +"""Client using the trio API.""" + +import trio +from websockets.trio.client import connect + + +async def hello(): + async with connect("ws://localhost:8765") as websocket: + await websocket.send("Hello world!") + message = await websocket.recv() + print(message) + + +if __name__ == "__main__": + trio.run(hello) diff --git a/example/trio/server.py b/example/trio/server.py new file mode 100755 index 000000000..78a5ab7bd --- /dev/null +++ b/example/trio/server.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +"""Server example using the trio API.""" + +import trio +from websockets.trio.server import serve + + +async def hello(websocket): + name = await websocket.recv() + print(f"<<< {name}") + + greeting = f"Hello {name}!" 
+ + await websocket.send(greeting) + print(f">>> {greeting}") + + +if __name__ == "__main__": + trio.run(serve, hello, 8765) diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index bf50bd6f5..862345fe5 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -12,7 +12,7 @@ from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import Headers, HeadersLike +from ..datastructures import HeadersLike from ..exceptions import ( InvalidMessage, InvalidProxyMessage, @@ -23,12 +23,13 @@ ) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import build_authorization_basic, build_host, validate_subprotocols +from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri +from ..uri import WebSocketURI, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection @@ -342,7 +343,7 @@ def __init__( if create_connection is None: create_connection = ClientConnection - def protocol_factory(uri: WebSocketURI) -> ClientConnection: + def factory(uri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( uri, @@ -364,18 +365,18 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection: return connection self.proxy = proxy - self.protocol_factory = protocol_factory + self.factory = factory self.additional_headers = additional_headers self.user_agent_header = user_agent_header self.process_exception = process_exception self.open_timeout = open_timeout self.logger = logger - self.connection_kwargs = kwargs + self.create_connection_kwargs = kwargs - async def create_connection(self) -> ClientConnection: + async def create_client_connection(self) -> ClientConnection: """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() - kwargs = self.connection_kwargs.copy() + kwargs = self.create_connection_kwargs.copy() ws_uri = parse_uri(self.uri) @@ -388,7 +389,7 @@ async def create_connection(self) -> ClientConnection: proxy = get_proxy(ws_uri) def factory() -> ClientConnection: - return self.protocol_factory(ws_uri) + return self.factory(ws_uri) if ws_uri.secure: kwargs.setdefault("ssl", True) @@ -459,7 +460,7 @@ def factory() -> ClientConnection: transport = new_transport connection.connection_made(transport) else: - raise AssertionError("unsupported proxy") + raise NotImplementedError(f"unsupported proxy: {proxy}") else: # Connect to the server directly. if kwargs.get("sock") is None: @@ -496,7 +497,7 @@ def process_redirect(self, exc: Exception) -> Exception | str: new_ws_uri = parse_uri(new_uri) # If connect() received a socket, it is closed and cannot be reused. - if self.connection_kwargs.get("sock") is not None: + if self.create_connection_kwargs.get("sock") is not None: return ValueError( f"cannot follow redirect to {new_uri} with a preexisting socket" ) @@ -512,7 +513,7 @@ def process_redirect(self, exc: Exception) -> Exception | str: or old_ws_uri.port != new_ws_uri.port ): # Cross-origin redirects on Unix sockets don't quite make sense. 
- if self.connection_kwargs.get("unix", False): + if self.create_connection_kwargs.get("unix", False): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " f"with a Unix socket" @@ -520,8 +521,8 @@ def process_redirect(self, exc: Exception) -> Exception | str: # Cross-origin redirects when host and port are overridden are ill-defined. if ( - self.connection_kwargs.get("host") is not None - or self.connection_kwargs.get("port") is not None + self.create_connection_kwargs.get("host") is not None + or self.create_connection_kwargs.get("port") is not None ): return ValueError( f"cannot follow cross-origin redirect to {new_uri} " @@ -540,14 +541,14 @@ async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): for _ in range(MAX_REDIRECTS): - self.connection = await self.create_connection() + connection = await self.create_client_connection() try: - await self.connection.handshake( + await connection.handshake( self.additional_headers, self.user_agent_header, ) except asyncio.CancelledError: - self.connection.transport.abort() + connection.transport.abort() raise except Exception as exc: # Always close the connection even though keep-alive is @@ -556,7 +557,7 @@ async def __await_impl__(self) -> ClientConnection: # protocol. In the current design of connect(), there is # no easy way to reuse the network connection that works # in every case nor to reinitialize the protocol. - self.connection.transport.abort() + connection.transport.abort() uri_or_exc = self.process_redirect(exc) # Response is a valid redirect; follow it. @@ -570,8 +571,8 @@ async def __await_impl__(self) -> ClientConnection: raise uri_or_exc from exc else: - self.connection.start_keepalive() - return self.connection + connection.start_keepalive() + return connection else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") @@ -586,7 +587,10 @@ async def __await_impl__(self) -> ClientConnection: # async with connect(...) as ...: ... async def __aenter__(self) -> ClientConnection: - return await self + if hasattr(self, "connection"): + raise RuntimeError("connect() isn't reentrant") + self.connection = await self + return self.connection async def __aexit__( self, @@ -594,7 +598,10 @@ async def __aexit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - await self.connection.close() + try: + await self.connection.close() + finally: + del self.connection # async for ... in connect(...): @@ -602,8 +609,8 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: delays: Generator[float] | None = None while True: try: - async with self as protocol: - yield protocol + async with self as connection: + yield connection except Exception as exc: # Determine whether the exception is retryable or fatal. # The API of process_exception is "return an exception or None"; @@ -632,7 +639,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]: traceback.format_exception_only(exc)[0].strip(), ) await asyncio.sleep(delay) - continue else: # The connection succeeded. Reset backoff. 
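The retry loop above backs the documented reconnection idiom for the asyncio client. For orientation, this is how calling code typically consumes it (a minimal sketch; the URI and the loop body are placeholders, not part of this change)::

    from websockets.asyncio.client import connect
    from websockets.exceptions import ConnectionClosed

    async def consume():
        # Each iteration yields a live connection; transient failures
        # trigger a new attempt after an exponential backoff delay,
        # fatal ones propagate out of the async for loop.
        async for websocket in connect("ws://localhost:8765"):
            try:
                async for message in websocket:
                    print(message)
            except ConnectionClosed:
                continue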
@@ -721,25 +727,6 @@ async def connect_socks_proxy( raise ProxyError("failed to connect to SOCKS proxy") from exc -def prepare_connect_request( - proxy: Proxy, - ws_uri: WebSocketURI, - user_agent_header: str | None = None, -) -> bytes: - host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) - headers = Headers() - headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) - if user_agent_header is not None: - headers["User-Agent"] = user_agent_header - if proxy.username is not None: - assert proxy.password is not None # enforced by parse_proxy() - headers["Proxy-Authorization"] = build_authorization_basic( - proxy.username, proxy.password - ) - # We cannot use the Request class because it supports only GET requests. - return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() - - class HTTPProxyConnection(asyncio.Protocol): def __init__( self, @@ -795,8 +782,7 @@ def eof_received(self) -> None: def connection_lost(self, exc: Exception | None) -> None: self.reader.feed_eof() - if exc is not None: - self.response.set_exception(exc) + self.run_parser() async def connect_http_proxy( @@ -815,8 +801,8 @@ async def connect_http_proxy( try: # This raises exceptions if the connection to the proxy fails. await protocol.response - except Exception: - transport.close() + except (asyncio.CancelledError, Exception): + transport.abort() raise return transport diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 592480f91..dffc55bb9 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -63,11 +63,10 @@ def __init__( self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - self.max_queue: tuple[int | None, int | None] if isinstance(max_queue, int) or max_queue is None: - self.max_queue = (max_queue, None) + self.max_queue_high, self.max_queue_low = max_queue, None else: - self.max_queue = max_queue + self.max_queue_high, self.max_queue_low = max_queue if isinstance(write_limit, int): write_limit = (write_limit, None) self.write_limit = write_limit @@ -101,12 +100,12 @@ def __init__( self.close_deadline: float | None = None # Protect sending fragmented messages. - self.fragmented_send_waiter: asyncio.Future[None] | None = None + self.send_in_progress: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} - self.latency: float = 0 + self.latency: float = 0.0 """ Latency of the connection, in seconds. @@ -417,8 +416,8 @@ async def send( You may override this behavior with the ``text`` argument: - * Set ``text=True`` to send a bytestring or bytes-like object - (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a Text_ frame. This improves performance when the message is already UTF-8 encoded, for example if the message contains JSON and you're using a JSON library that produces a bytestring. @@ -426,7 +425,7 @@ async def send( frame. This may be useful for servers that expect binary frames instead of text frames. 
- :meth:`send` also accepts an iterable or an asynchronous iterable of + :meth:`send` also accepts an iterable or asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the same type, or else :meth:`send` will raise a @@ -468,8 +467,8 @@ async def send( """ # While sending a fragmented message, prevent sending other messages # until all fragments are sent. - while self.fragmented_send_waiter is not None: - await asyncio.shield(self.fragmented_send_waiter) + while self.send_in_progress is not None: + await asyncio.shield(self.send_in_progress) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -502,8 +501,8 @@ async def send( except StopIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -549,8 +548,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None # Fragmented message -- async iterator. @@ -561,8 +560,8 @@ async def send( except StopAsyncIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -610,8 +609,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None else: raise TypeError("data must be str, bytes, iterable, or async iterable") @@ -639,7 +638,7 @@ async def close( # The context manager takes care of waiting for the TCP connection # to terminate after calling a method that sends a close frame. async with self.send_context(): - if self.fragmented_send_waiter is not None: + if self.send_in_progress is not None: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", @@ -681,9 +680,9 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]: :: - pong_waiter = await ws.ping() + pong_received = await ws.ping() # only if you want to wait for the corresponding pong - latency = await pong_waiter + latency = await pong_received Raises: ConnectionClosed: When the connection is closed. @@ -700,19 +699,20 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = self.loop.create_future() + pong_received = self.loop.create_future() + ping_timestamp = self.loop.time() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. 
- self.pong_waiters[data] = (pong_waiter, self.loop.time()) + self.pending_pings[data] = (pong_received, ping_timestamp) self.protocol.send_ping(data) - return pong_waiter + return pong_received async def pong(self, data: DataLike = b"") -> None: """ @@ -761,7 +761,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = self.loop.time() @@ -770,41 +770,39 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - if not pong_waiter.done(): - pong_waiter.set_result(latency) + if not pong_received.done(): + pong_received.set_result(latency) if ping_id == data: self.latency = latency break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] - def abort_pings(self) -> None: + def terminate_pending_pings(self) -> None: """ - Raise ConnectionClosed in pending pings. - - They'll never receive a pong once the connection is closed. + Raise ConnectionClosed in pending pings when the connection is closed. """ assert self.protocol.state is CLOSED exc = self.protocol.close_exc - for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - if not pong_waiter.done(): - pong_waiter.set_exception(exc) + for pong_received, _ping_timestamp in self.pending_pings.values(): + if not pong_received.done(): + pong_received.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - pong_waiter.cancel() + pong_received.cancel() - self.pong_waiters.clear() + self.pending_pings.clear() async def keepalive(self) -> None: """ @@ -825,7 +823,7 @@ async def keepalive(self) -> None: # connection to be closed before raising ConnectionClosed. # However, connection_lost() cancels keepalive_task before # it gets a chance to resume excuting. - pong_waiter = await self.ping() + pong_received = await self.ping() if self.debug: self.logger.debug("% sent keepalive ping") @@ -834,10 +832,11 @@ async def keepalive(self) -> None: async with asyncio_timeout(self.ping_timeout): # connection_lost cancels keepalive immediately # after setting a ConnectionClosed exception on - # pong_waiter. A CancelledError is raised here, + # pong_received. A CancelledError is raised here, # not a ConnectionClosed exception. - latency = await pong_waiter - self.logger.debug("% received keepalive pong") + latency = await pong_received + if self.debug: + self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("- timed out waiting for keepalive pong") @@ -908,14 +907,13 @@ async def send_context( # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. 
- # Since we tested earlier that protocol.state was OPEN + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN # (or CONNECTING), self.close_deadline is still None. + assert self.close_deadline is None if self.close_timeout is not None: - assert self.close_deadline is None self.close_deadline = self.loop.time() + self.close_timeout - # Write outgoing data to the socket and enforce flow control. + # Write outgoing data to the socket with flow control. try: self.send_data() await self.drain() @@ -933,9 +931,8 @@ async def send_context( # will be closing soon if it isn't in the expected state. wait_for_close = True # Calculate close_deadline if it wasn't set yet. - if self.close_timeout is not None: - if self.close_deadline is None: - self.close_deadline = self.loop.time() + self.close_timeout + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = self.loop.time() + self.close_timeout raise_close_exc = True # If the connection is expected to close soon and the close timeout @@ -966,9 +963,6 @@ def send_data(self) -> None: """ Send outgoing data. - Raises: - OSError: When a socket operations fails. - """ for data in self.protocol.data_to_send(): if data: @@ -982,7 +976,7 @@ def send_data(self) -> None: # OSError is plausible. uvloop can raise RuntimeError here. try: self.transport.write_eof() - except (OSError, RuntimeError): # pragma: no cover + except Exception: # pragma: no cover pass # Else, close the TCP connection. else: # pragma: no cover @@ -1005,7 +999,8 @@ def set_recv_exc(self, exc: BaseException | None) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) self.recv_messages = Assembler( - *self.max_queue, + self.max_queue_high, + self.max_queue_low, pause=transport.pause_reading, resume=transport.resume_reading, ) @@ -1022,7 +1017,7 @@ def connection_lost(self, exc: Exception | None) -> None: # Abort recv() and pending pings with a ConnectionClosed exception. self.recv_messages.close() - self.abort_pings() + self.terminate_pending_pings() if self.keepalive_task is not None: self.keepalive_task.cancel() @@ -1092,12 +1087,10 @@ def data_received(self, data: bytes) -> None: self.logger.debug("! error while sending data", exc_info=True) self.set_recv_exc(exc) + # If needed, set the close deadline based on the close timeout. if self.protocol.close_expected(): - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. - if self.close_timeout is not None: - if self.close_deadline is None: - self.close_deadline = self.loop.time() + self.close_timeout + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = self.loop.time() + self.close_timeout for event in events: # This isn't expected to raise an exception. 
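The next hunks touch ``broadcast()`` and ``Server.connections``. For orientation, here is the usage they support (a sketch, assuming ``server`` was returned by ``serve()``)::

    from websockets.asyncio.server import broadcast

    def announce(server, message):
        # connections only includes OPEN connections after this change;
        # broadcast() never blocks: connections that are busy with a
        # fragmented send are skipped or reported, not awaited.
        broadcast(server.connections, message)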
@@ -1205,7 +1198,7 @@ def broadcast( if connection.protocol.state is not OPEN: continue - if connection.fragmented_send_waiter is not None: + if connection.send_in_progress is not None: if raise_exceptions: exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index af26d5d7a..8ac64da34 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -81,8 +81,7 @@ class Assembler: """ - # coverage reports incorrectly: "line NN didn't jump to the function exit" - def __init__( # pragma: no cover + def __init__( self, high: int | None = None, low: int | None = None, @@ -155,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # until get() fetches a complete message or is canceled. try: - # First frame + # Fetch the first frame. frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY @@ -163,7 +162,7 @@ async def get(self, decode: bool | None = None) -> Data: decode = frame.opcode is OP_TEXT frames = [frame] - # Following frames, for fragmented messages + # Fetch subsequent frames for fragmented messages. while not frame.fin: try: frame = await self.frames.get(not self.closed) @@ -230,7 +229,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. - # First frame + # Yield the first frame. try: frame = await self.frames.get(not self.closed) except asyncio.CancelledError: @@ -247,7 +246,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # Convert to bytes when frame.data is a bytearray. yield bytes(frame.data) - # Following frames, for fragmented messages + # Yield subsequent frames for fragmented messages. while not frame.fin: # We cannot handle asyncio.CancelledError because we don't buffer # previous fragments — we're streaming them. Canceling get_iter() @@ -280,22 +279,22 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.high is None: return - # Check for "> high" to support high = 0 + # Check for "> high" to support high = 0. if len(self.frames) > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.low is None: return - # Check for "<= low" to support low = 0 + # Check for "<= low" to support low = 0. if len(self.frames) <= self.low and self.paused: self.paused = False self.resume() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index ef9bd807f..2770e3dcb 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -310,7 +310,11 @@ def connections(self) -> set[ServerConnection]: It can be useful in combination with :func:`~broadcast`. 
""" - return {connection for connection in self.handlers if connection.state is OPEN} + return { + connection + for connection in self.handlers + if connection.protocol.state is OPEN + } def wrap(self, server: asyncio.Server) -> None: """ @@ -351,6 +355,8 @@ async def conn_handler(self, connection: ServerConnection) -> None: """ try: + # Apply open_timeout to the WebSocket handshake. + # Use ssl_handshake_timeout for the TLS handshake. async with asyncio_timeout(self.open_timeout): try: await connection.handshake( @@ -425,7 +431,7 @@ def close( ``code`` and ``reason`` can be customized, for example to use code 1012 (service restart). - * Wait until all connection handlers terminate. + * Wait until all connection handlers have returned. :meth:`close` is idempotent. @@ -452,6 +458,7 @@ async def _close( self.logger.info("server closing") # Stop accepting new connections. + # Reject OPENING connections with HTTP 503 -- see handshake(). self.server.close() # Wait until all accepted connections reach connection_made() and call @@ -459,15 +466,12 @@ async def _close( # details. This workaround can be removed when dropping Python < 3.11. await asyncio.sleep(0) - # After server.close(), handshake() closes OPENING connections with an - # HTTP 503 error. - + # Close OPEN connections. if close_connections: - # Close OPEN connections with code 1001 by default. close_tasks = [ asyncio.create_task(connection.close(code, reason)) for connection in self.handlers - if connection.protocol.state is not CONNECTING + if connection.protocol.state is OPEN ] # asyncio.wait doesn't accept an empty first argument. if close_tasks: @@ -476,7 +480,7 @@ async def _close( # Wait until all TCP connections are closed. await self.server.wait_closed() - # Wait until all connection handlers terminate. + # Wait until all connection handlers have returned. # asyncio.wait doesn't accept an empty first argument. if self.handlers: await asyncio.wait(self.handlers.values()) @@ -594,7 +598,7 @@ class serve: from websockets.asyncio.server import serve - def handler(websocket): + async def handler(websocket): ... # set this future to exit the server diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index f2d8ea7b5..3a479a05a 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -1372,7 +1372,7 @@ def fail_connection( This requires: - 1. Stopping all processing of incoming data, which means cancelling + 1. Stopping all processing of incoming data, which means canceling :attr:`transfer_data_task`. The close code will be 1006 unless a close frame was received earlier. diff --git a/src/websockets/proxy.py b/src/websockets/proxy.py new file mode 100644 index 000000000..a343b37bc --- /dev/null +++ b/src/websockets/proxy.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import dataclasses +import urllib.parse +import urllib.request + +from .datastructures import Headers +from .exceptions import InvalidProxy +from .headers import build_authorization_basic, build_host +from .http11 import USER_AGENT +from .uri import DELIMS, WebSocketURI + + +__all__ = ["get_proxy", "parse_proxy", "Proxy"] + + +@dataclasses.dataclass +class Proxy: + """ + Proxy address. + + Attributes: + scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, + ``"https"``, or ``"http"``. + host: Normalized to lower case. + port: Always set even if it's the default. + username: Available when the proxy address contains `User Information`_. 
+ password: Available when the proxy address contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + scheme: str + host: str + port: int + username: str | None = None + password: str | None = None + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_proxy(proxy: str) -> Proxy: + """ + Parse and validate a proxy. + + Args: + proxy: proxy. + + Returns: + Parsed proxy. + + Raises: + InvalidProxy: If ``proxy`` isn't a valid proxy. + + """ + parsed = urllib.parse.urlparse(proxy) + if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: + raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") + if parsed.hostname is None: + raise InvalidProxy(proxy, "hostname isn't provided") + if parsed.path not in ["", "/"]: + raise InvalidProxy(proxy, "path is meaningless") + if parsed.query != "": + raise InvalidProxy(proxy, "query is meaningless") + if parsed.fragment != "": + raise InvalidProxy(proxy, "fragment is meaningless") + + scheme = parsed.scheme + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidProxy(proxy, "username provided without password") + + try: + proxy.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return Proxy(scheme, host, port, username, password) + + +def get_proxy(uri: WebSocketURI) -> str | None: + """ + Return the proxy to use for connecting to the given WebSocket URI, if any. + + """ + if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): + return None + + # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if + # available, else favor the proxy for HTTPS connections over the proxy for + # HTTP connections. + + # The priority of a proxy for WebSocket connections is unspecified. We give + # it the highest priority. This makes it easy to configure a specific proxy + # for websockets. + + # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or + # as {"https": "socks5h://host:port"} depending on whether they're declared + # in the operating system or in environment variables. 
+ + proxies = urllib.request.getproxies() + if uri.secure: + schemes = ["wss", "socks", "https"] + else: + schemes = ["ws", "socks", "https", "http"] + + for scheme in schemes: + proxy = proxies.get(scheme) + if proxy is not None: + if scheme == "socks" and proxy.startswith("http://"): + proxy = "socks5h://" + proxy[7:] + return proxy + else: + return None + + +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = USER_AGENT, +) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8042a3744..a70952932 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -8,16 +8,17 @@ from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import Headers, HeadersLike +from ..datastructures import HeadersLike from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import build_authorization_basic, build_host, validate_subprotocols +from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request from ..streams import StreamReader from ..typing import BytesLike, LoggerLike, Origin, Subprotocol -from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri +from ..uri import WebSocketURI, parse_uri from .connection import Connection from .utils import Deadline @@ -156,6 +157,7 @@ def connect( logger: LoggerLike | None = None, # Escape hatch for advanced customization create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to socket.create_connection **kwargs: Any, ) -> ClientConnection: """ @@ -229,6 +231,7 @@ def connect( Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. OSError: If the TCP connection fails. InvalidHandshake: If the opening handshake fails. TimeoutError: If the opening handshake times out. 
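For context, the hunks above touch the sync client's ``connect()``, which is typically used as a context manager. A minimal sketch (the URI is a placeholder)::

    from websockets.sync.client import connect

    with connect("ws://localhost:8765", open_timeout=5) as websocket:
        websocket.send("Hello!")
        print(websocket.recv())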
@@ -476,25 +479,6 @@ def connect_socks_proxy( raise ProxyError("failed to connect to SOCKS proxy") from exc -def prepare_connect_request( - proxy: Proxy, - ws_uri: WebSocketURI, - user_agent_header: str | None = None, -) -> bytes: - host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) - headers = Headers() - headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) - if user_agent_header is not None: - headers["User-Agent"] = user_agent_header - if proxy.username is not None: - assert proxy.password is not None # enforced by parse_proxy() - headers["Proxy-Authorization"] = build_authorization_basic( - proxy.username, proxy.password - ) - # We cannot use the Request class because it supports only GET requests. - return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() - - def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: reader = StreamReader() parser = Response.parse( @@ -557,7 +541,8 @@ def connect_http_proxy( # Send CONNECT request to the proxy and read response. - sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header)) + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + sock.sendall(request) try: read_connect_response(sock, deadline) except Exception: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 6ef1ef039..bf3fdccff 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -59,11 +59,10 @@ def __init__( self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - self.max_queue: tuple[int | None, int | None] if isinstance(max_queue, int) or max_queue is None: - self.max_queue = (max_queue, None) + max_queue_high, max_queue_low = max_queue, None else: - self.max_queue = max_queue + max_queue_high, max_queue_low = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -92,7 +91,8 @@ def __init__( # Assembler turning frames into messages and serializing reads. self.recv_messages = Assembler( - *self.max_queue, + max_queue_high, + max_queue_low, pause=self.recv_flow_control.acquire, resume=self.recv_flow_control.release, ) @@ -104,9 +104,9 @@ def __init__( self.send_in_progress = False # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} + self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} - self.latency: float = 0 + self.latency: float = 0.0 """ Latency of the connection, in seconds. @@ -284,8 +284,8 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data is ``0`` or negative, check if a message has been received already and return it, else raise :exc:`TimeoutError`. - If the message is fragmented, wait until all fragments are received, - reassemble them, and return the whole message. + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. Args: timeout: Timeout for receiving a message in seconds. @@ -425,8 +425,8 @@ def send( You may override this behavior with the ``text`` argument: - * Set ``text=True`` to send a bytestring or bytes-like object - (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a Text_ frame. 
This improves performance when the message is already UTF-8 encoded, for example if the message contains JSON and you're using a JSON library that produces a bytestring. @@ -530,7 +530,7 @@ def send( self.protocol.send_binary(chunk, fin=False) encode = False else: - raise TypeError("data iterable must contain bytes or str") + raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: @@ -543,7 +543,7 @@ def send( assert self.send_in_progress self.protocol.send_continuation(chunk, fin=False) else: - raise TypeError("data iterable must contain uniform types") + raise TypeError("iterable must contain uniform types") # Final fragment. with self.send_context(): @@ -633,8 +633,9 @@ def ping( :: - pong_event = ws.ping() - pong_event.wait() # only if you want to wait for the pong + pong_received = ws.ping() + # only if you want to wait for the corresponding pong + pong_received.wait() Raises: ConnectionClosed: When the connection is closed. @@ -651,17 +652,18 @@ def ping( with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) + pong_received = threading.Event() + ping_timestamp = time.monotonic() + self.pending_pings[data] = (pong_received, ping_timestamp, ack_on_close) self.protocol.send_ping(data) - return pong_waiter + return pong_received def pong(self, data: DataLike = b"") -> None: """ @@ -711,7 +713,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = time.monotonic() @@ -721,21 +723,21 @@ def acknowledge_pings(self, data: bytes) -> None: ping_id = None ping_ids = [] for ping_id, ( - pong_waiter, + pong_received, ping_timestamp, _ack_on_close, - ) in self.pong_waiters.items(): + ) in self.pending_pings.items(): ping_ids.append(ping_id) - pong_waiter.set() + pong_received.set() if ping_id == data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. 
for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def acknowledge_pending_pings(self) -> None: """ @@ -744,11 +746,11 @@ def acknowledge_pending_pings(self) -> None: """ assert self.protocol.state is CLOSED - for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): if ack_on_close: - pong_waiter.set() + pong_received.set() - self.pong_waiters.clear() + self.pending_pings.clear() def keepalive(self) -> None: """ @@ -766,15 +768,14 @@ def keepalive(self) -> None: break try: - pong_waiter = self.ping(ack_on_close=True) + pong_received = self.ping(ack_on_close=True) except ConnectionClosed: break if self.debug: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: - # - if pong_waiter.wait(self.ping_timeout): + if pong_received.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") else: @@ -808,15 +809,17 @@ def recv_events(self) -> None: Run this method in a thread as long as the connection is alive. - ``recv_events()`` exits immediately when the ``self.socket`` is closed. + ``recv_events()`` exits immediately when ``self.socket`` is closed. """ try: while True: try: + # If the assembler buffer is full, block until it drains. with self.recv_flow_control: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + pass + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -859,9 +862,8 @@ def recv_events(self) -> None: self.set_recv_exc(exc) break + # If needed, set the close deadline based on the close timeout. if self.protocol.close_expected(): - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. if self.close_deadline is None: self.close_deadline = Deadline(self.close_timeout) @@ -878,6 +880,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. + with self.protocol_mutex: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() @@ -957,11 +960,10 @@ def send_context( # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN # (or CONNECTING) and we didn't release protocol_mutex, - # it is certain that self.close_deadline is still None. + # self.close_deadline is still None. assert self.close_deadline is None self.close_deadline = Deadline(self.close_timeout) # Write outgoing data to the socket. @@ -983,6 +985,9 @@ def send_context( # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) raise_close_exc = True # To avoid a deadlock, release the connection lock by exiting the @@ -991,13 +996,10 @@ def send_context( # If the connection is expected to close soon and the close timeout # elapses, close the socket to terminate the connection. 
if wait_for_close: - if self.close_deadline is None: - timeout = self.close_timeout - else: - # Thread.join() returns immediately if timeout is negative. - timeout = self.close_deadline.timeout(raise_if_elapsed=False) + # Thread.join() returns immediately if timeout is negative. + assert self.close_deadline is not None + timeout = self.close_deadline.timeout(raise_if_elapsed=False) self.recv_events_thread.join(timeout) - if self.recv_events_thread.is_alive(): # There's no risk to overwrite another error because # original_exc is never set when wait_for_close is True. @@ -1023,9 +1025,6 @@ def send_data(self) -> None: This method requires holding protocol_mutex. - Raises: - OSError: When a socket operations fails. - """ assert self.protocol_mutex.locked() for data in self.protocol.data_to_send(): @@ -1047,7 +1046,7 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ assert self.protocol_mutex.locked() - if self.recv_exc is None: # pragma: no branch + if self.recv_exc is None: self.recv_exc = exc def close_socket(self) -> None: diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index c4d04bc83..d95519f63 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -165,7 +165,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: try: deadline = Deadline(timeout) - # First frame + # Fetch the first frame. frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) with self.mutex: self.maybe_resume() @@ -174,7 +174,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: decode = frame.opcode is OP_TEXT frames = [frame] - # Following frames, for fragmented messages + # Fetch subsequent frames for fragmented messages. while not frame.fin: try: frame = self.get_next_frame( @@ -245,7 +245,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. - # First frame + # Yield the first frame. frame = self.get_next_frame() with self.mutex: self.maybe_resume() @@ -259,7 +259,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: # Convert to bytes when frame.data is a bytearray. yield bytes(frame.data) - # Following frames, for fragmented messages + # Yield subsequent frames for fragmented messages. while not frame.fin: frame = self.get_next_frame() with self.mutex: @@ -300,26 +300,26 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.high is None: return assert self.mutex.locked() - # Check for "> high" to support high = 0 + # Check for "> high" to support high = 0. if self.frames.qsize() > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.low is None: return assert self.mutex.locked() - # Check for "<= low" to support low = 0 + # Check for "<= low" to support low = 0. 
if self.frames.qsize() <= self.low and self.paused: self.paused = False self.resume() diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index ffd82fbad..bf8829772 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -605,7 +605,12 @@ def protocol_select_subprotocol( connection.recv_events_thread.join() return - assert connection.protocol.state is OPEN + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. + connection.close_socket() + connection.recv_events_thread.join() + return + try: connection.start_keepalive() handler(connection) diff --git a/src/websockets/trio/__init__.py b/src/websockets/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/trio/client.py b/src/websockets/trio/client.py new file mode 100644 index 000000000..13f98fbd2 --- /dev/null +++ b/src/websockets/trio/client.py @@ -0,0 +1,734 @@ +from __future__ import annotations + +import logging +import os +import ssl as ssl_module +import sys +import traceback +import urllib.parse +from collections.abc import AsyncIterator, Generator, Sequence +from types import TracebackType +from typing import Any, Callable, Literal + +import trio + +from ..asyncio.client import process_exception +from ..client import ClientProtocol, backoff +from ..datastructures import HeadersLike +from ..exceptions import ( + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) +from ..extensions.base import ClientExtensionFactory +from ..extensions.permessage_deflate import enable_client_permessage_deflate +from ..headers import validate_subprotocols +from ..http11 import USER_AGENT, Response +from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request +from ..streams import StreamReader +from ..typing import LoggerLike, Origin, Subprotocol +from ..uri import WebSocketURI, parse_uri +from .connection import Connection +from .utils import race_events + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +__all__ = ["connect", "ClientConnection"] + +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + + +class ClientConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket client connection. + + :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines + for receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. + + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket server. + protocol: Sans-I/O connection. 
+ + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ClientProtocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ClientProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.response_rcvd = trio.Event() + + async def handshake( + self, + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + ) -> None: + """ + Perform the opening handshake. + + """ + async with self.send_context(expected_state=CONNECTING): + self.request = self.protocol.connect() + if additional_headers is not None: + self.request.headers.update(additional_headers) + if user_agent_header is not None: + self.request.headers.setdefault("User-Agent", user_agent_header) + self.protocol.send_request(self.request) + + await race_events(self.response_rcvd, self.stream_closed) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a response, when the response cannot be parsed, or when the + # response fails the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake response. + if self.response is None: + assert isinstance(event, Response) + self.response = event + self.response_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +# This is spelled in lower case because it's exposed as a callable in the API. +class connect: + """ + Connect to the WebSocket server at ``uri``. + + This coroutine returns a :class:`ClientConnection` instance, which you can + use to send and receive messages. + + :func:`connect` may be used as an asynchronous context manager:: + + from websockets.trio.client import connect + + async with connect(...) as websocket: + ... + + The connection is closed automatically when exiting the context. + + :func:`connect` can be used as an infinite asynchronous iterator to + reconnect automatically on errors:: + + async for websocket in connect(...): + try: + ... + except websockets.exceptions.ConnectionClosed: + continue + + If the connection fails with a transient error, it is retried with + exponential backoff. If it fails with a fatal error, the exception is + raised, breaking out of the loop. + + The connection is closed automatically after each iteration of the loop. + + Args: + uri: URI of the WebSocket server. + stream: Preexisting TCP stream. ``stream`` overrides the host and port + from ``uri``. You may call :func:`~trio.open_tcp_stream` to create a + suitable TCP stream. + ssl: Configuration for enabling TLS on the connection. + server_hostname: Host name for the TLS handshake. ``server_hostname`` + overrides the host name from ``uri``. + origin: Value of the ``Origin`` header, for servers that require it. + extensions: List of supported extensions, in order in which they + should be negotiated and run. + subprotocols: List of supported subprotocols, in order of decreasing + preference. + compression: The "permessage-deflate" extension is enabled by default. + Set ``compression`` to :obj:`None` to disable it. See the + :doc:`compression guide <../../topics/compression>` for details. 
+ additional_headers (HeadersLike | None): Arbitrary HTTP headers to add + to the handshake request. + user_agent_header: Value of the ``User-Agent`` request header. + It defaults to ``"Python/x.y.z websockets/X.Y"``. + Setting it to :obj:`None` removes the header. + proxy: If a proxy is configured, it is used by default. Set ``proxy`` + to :obj:`None` to disable the proxy or to the address of a proxy + to override the system configuration. See the :doc:`proxy docs + <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. + process_exception: When reconnecting automatically, tell whether an + error is transient or fatal. The default behavior is defined by + :func:`process_exception`. Refer to its documentation for details. + open_timeout: Timeout for opening the connection in seconds. + :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. + close_timeout: Timeout for closing the connection in seconds. + :obj:`None` disables the timeout. + max_size: Maximum size of incoming messages in bytes. + :obj:`None` disables the limit. You may pass a ``(max_message_size, + max_fragment_size)`` tuple to set different limits for messages and + fragments when you expect long messages sent in short fragments. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. If you want to disable flow control entirely, + you may set it to ``None``, although that's a bad idea. + logger: Logger for this client. + It defaults to ``logging.getLogger("websockets.client")``. + See the :doc:`logging guide <../../topics/logging>` for details. + create_connection: Factory for the :class:`ClientConnection` managing + the connection. Set it to a wrapper or a subclass to customize + connection handling. + + Any other keyword arguments are passed to :func:`~trio.open_tcp_stream`. + + Raises: + InvalidURI: If ``uri`` isn't a valid WebSocket URI. + InvalidProxy: If ``proxy`` isn't a valid proxy. + OSError: If the TCP connection fails. + InvalidHandshake: If the opening handshake fails. + TimeoutError: If the opening handshake times out. + + """ + + # Arguments of type SSLContext don't render correctly in the documentation + # because of https://github.com/sphinx-doc/sphinx/issues/13838. 
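To make the new :mod:`trio` client concrete, here is a minimal hello-world client matching the API defined in this module (host and port are placeholders)::

    import trio
    from websockets.trio.client import connect

    async def hello():
        async with connect("ws://localhost:8765") as websocket:
            await websocket.send("Hello world!")
            print(await websocket.recv())

    trio.run(hello)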
+ + def __init__( + self, + uri: str, + *, + # TCP/TLS + stream: trio.abc.Stream | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + # WebSocket + origin: Origin | None = None, + extensions: Sequence[ClientExtensionFactory] | None = None, + subprotocols: Sequence[Subprotocol] | None = None, + compression: str | None = "deflate", + # HTTP + additional_headers: HeadersLike | None = None, + user_agent_header: str | None = USER_AGENT, + proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, + process_exception: Callable[[Exception], Exception | None] = process_exception, + # Timeouts + open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + # Limits + max_size: int | None | tuple[int | None, int | None] = 2**20, + max_queue: int | None | tuple[int | None, int | None] = 16, + # Logging + logger: LoggerLike | None = None, + # Escape hatch for advanced customization + create_connection: type[ClientConnection] | None = None, + # Other keyword arguments are passed to trio.open_tcp_stream + **kwargs: Any, + ) -> None: + self.uri = uri + + ws_uri = parse_uri(uri) + self.ws_uri = ws_uri + + if not ws_uri.secure and ssl is not None: + raise ValueError("ssl argument is incompatible with a ws:// URI") + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_client_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if proxy is True: + proxy = get_proxy(ws_uri) + + if logger is None: + logger = logging.getLogger("websockets.client") + + if create_connection is None: + create_connection = ClientConnection + + self.stream = stream + self.ssl = ssl + self.server_hostname = server_hostname + self.proxy = proxy + self.proxy_ssl = proxy_ssl + self.proxy_server_hostname = proxy_server_hostname + self.additional_headers = additional_headers + self.user_agent_header = user_agent_header + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.create_connection = create_connection + self.open_tcp_stream_kwargs = kwargs + self.protocol_kwargs = dict( + origin=origin, + extensions=extensions, + subprotocols=subprotocols, + max_size=max_size, + logger=logger, + ) + self.connection_kwargs = dict( + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + async def open_tcp_stream(self) -> trio.abc.Stream: + """Open a TCP connection to the server, possibly through a proxy.""" + # TCP connection is already established. + if self.stream is not None: + return self.stream + + # Connect to the server through a proxy. 
+ elif self.proxy is not None: + proxy_parsed = parse_proxy(self.proxy) + + if proxy_parsed.scheme[:5] == "socks": + return await connect_socks_proxy( + proxy_parsed, + self.ws_uri, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + elif proxy_parsed.scheme[:4] == "http": + if proxy_parsed.scheme != "https" and self.proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + return await connect_http_proxy( + proxy_parsed, + self.ws_uri, + user_agent_header=self.user_agent_header, + ssl=self.proxy_ssl, + server_hostname=self.proxy_server_hostname, + local_address=self.open_tcp_stream_kwargs.get("local_address"), + ) + + else: + raise NotImplementedError(f"unsupported proxy: {self.proxy}") + + # Connect to the server directly. + else: + kwargs = self.open_tcp_stream_kwargs.copy() + kwargs.setdefault("host", self.ws_uri.host) + kwargs.setdefault("port", self.ws_uri.port) + return await trio.open_tcp_stream(**kwargs) + + async def enable_tls(self, stream: trio.abc.Stream) -> trio.abc.Stream: + """Enable TLS on the connection.""" + if self.ssl is None: + ssl = ssl_module.create_default_context() + else: + ssl = self.ssl + if self.server_hostname is None: + server_hostname = self.ws_uri.host + else: + server_hostname = self.server_hostname + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + return ssl_stream + + async def open_connection(self, nursery: trio.Nursery) -> ClientConnection: + """Create a WebSocket connection.""" + stream: trio.abc.Stream + stream = await self.open_tcp_stream() + + try: + if self.ws_uri.secure: + stream = await self.enable_tls(stream) + + protocol = ClientProtocol( + self.ws_uri, + **self.protocol_kwargs, # type: ignore + ) + + connection = self.create_connection( # default is ClientConnection + nursery, + stream, + protocol, + **self.connection_kwargs, # type: ignore + ) + + await connection.handshake( + self.additional_headers, + self.user_agent_header, + ) + + return connection + + except trio.Cancelled: + await trio.aclose_forcefully(stream) + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + raise AssertionError("nursery should be canceled") + except Exception: + # Always close the connection even though keep-alive is the default + # in HTTP/1.1 because the current implementation ties opening the + # TCP/TLS connection with initializing the WebSocket protocol. + await trio.aclose_forcefully(stream) + raise + + def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_ws_uri = parse_uri(self.uri) + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_ws_uri = parse_uri(new_uri) + + # If connect() received a stream, it is closed and cannot be reused. + if self.stream is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting stream" + ) + + # TLS downgrade is forbidden. 
+ if old_ws_uri.secure and not new_ws_uri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_ws_uri.secure != new_ws_uri.secure + or old_ws_uri.host != new_ws_uri.host + or old_ws_uri.port != new_ws_uri.port + ): + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.open_tcp_stream_kwargs.get("host") is not None + or self.open_tcp_stream_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri, new_ws_uri + + async def connect(self, nursery: trio.Nursery) -> ClientConnection: + try: + with ( + trio.CancelScope() + if self.open_timeout is None + else trio.fail_after(self.open_timeout) + ): + for _ in range(MAX_REDIRECTS): + try: + connection = await self.open_connection(nursery) + except Exception as exc: + uri_or_exc = self.process_redirect(exc) + # Response is a valid redirect; follow it. + if isinstance(uri_or_exc, Exception): + if uri_or_exc is exc: + raise + else: + raise uri_or_exc from exc + # Response isn't a valid redirect; raise the exception. + else: + self.uri, self.ws_uri = uri_or_exc + continue + + else: + connection.start_keepalive() + return connection + else: + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + + except trio.TooSlowError as exc: + # Re-raise exception with an informative error message. + raise TimeoutError("timed out during opening handshake") from exc + + # Do not define __await__ for... = await nursery.start(connect, ...) + # because it doesn't look idiomatic in Trio. + + # async with connect(...) as ...: ... + + async def __aenter__(self) -> ClientConnection: + await self.__aenter_nursery__() + try: + self.connection = await self.connect(self.nursery) + return self.connection + except BaseException as exc: + await self.__aexit_nursery__(type(exc), exc, exc.__traceback__) + raise AssertionError("expected __aexit_nursery__ to re-raise the exception") + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + try: + await self.connection.aclose() + del self.connection + finally: + await self.__aexit_nursery__(exc_type, exc_value, traceback) + + async def __aenter_nursery__(self) -> None: + if hasattr(self, "nursery_manager"): # pragma: no cover + raise RuntimeError("connect() isn't reentrant") + self.nursery_manager = trio.open_nursery() + self.nursery = await self.nursery_manager.__aenter__() + + async def __aexit_nursery__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + # We need a nursery to start the recv_events and keepalive coroutines. + # They aren't expected to raise exceptions; instead they catch and log + # all unexpected errors. 
To keep the nursery an implementation detail, + # unwrap exceptions raised by user code -- per the second option here: + # https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors + try: + await self.nursery_manager.__aexit__(exc_type, exc_value, traceback) + except BaseException as exc: + assert isinstance(exc, BaseExceptionGroup) + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise AssertionError( + "unexpected multiple exceptions; please file a bug report" + ) from exc + finally: + del self.nursery_manager + + # async for ... in connect(...): + + async def __aiter__(self) -> AsyncIterator[ClientConnection]: + delays: Generator[float] | None = None + while True: + try: + async with self as connection: + yield connection + except Exception as exc: + # Determine whether the exception is retryable or fatal. + # The API of process_exception is "return an exception or None"; + # "raise an exception" is also supported because it's a frequent + # mistake. It isn't documented in order to keep the API simple. + try: + new_exc = self.process_exception(exc) + except Exception as raised_exc: + new_exc = raised_exc + + # The connection failed with a fatal error. + # Raise the exception and exit the loop. + if new_exc is exc: + raise + if new_exc is not None: + raise new_exc from exc + + # The connection failed with a retryable error. + # Start or continue backoff and reconnect. + if delays is None: + delays = backoff() + delay = next(delays) + self.logger.info( + "connect failed; reconnecting in %.1f seconds: %s", + delay, + traceback.format_exception_only(exc)[0].strip(), + ) + await trio.sleep(delay) + continue + + else: + # The connection succeeded. Reset backoff. + delays = None + + +try: + from python_socks import ProxyType + from python_socks.async_.trio import Proxy as SocksProxy + +except ImportError: + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + raise ImportError("connecting through a SOCKS proxy requires python-socks") + +else: + SOCKS_PROXY_TYPES = { + "socks5h": ProxyType.SOCKS5, + "socks5": ProxyType.SOCKS5, + "socks4a": ProxyType.SOCKS4, + "socks4": ProxyType.SOCKS4, + } + + SOCKS_PROXY_RDNS = { + "socks5h": True, + "socks5": False, + "socks4a": True, + "socks4": False, + } + + async def connect_socks_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + **kwargs: Any, + ) -> trio.abc.Stream: + """Connect via a SOCKS proxy and return the socket.""" + socks_proxy = SocksProxy( + SOCKS_PROXY_TYPES[proxy.scheme], + proxy.host, + proxy.port, + proxy.username, + proxy.password, + SOCKS_PROXY_RDNS[proxy.scheme], + ) + # connect() is documented to raise OSError. + # socks_proxy.connect() re-raises trio.TooSlowError as ProxyTimeoutError. + # Wrap other exceptions in ProxyError, a subclass of InvalidHandshake. 
+ try: + return trio.SocketStream( + await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs) + ) + except OSError: + raise + except Exception as exc: + raise ProxyError("failed to connect to SOCKS proxy") from exc + + +async def read_connect_response(stream: trio.abc.Stream) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + proxy=True, + ) + try: + while True: + data = await stream.receive_some(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + + +async def connect_http_proxy( + proxy: Proxy, + ws_uri: WebSocketURI, + *, + user_agent_header: str | None = None, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, + **kwargs: Any, +) -> trio.abc.Stream: + stream: trio.abc.Stream + stream = await trio.open_tcp_stream(proxy.host, proxy.port, **kwargs) + + try: + # Initialize TLS wrapper and perform TLS handshake + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + ssl_stream = trio.SSLStream( + stream, + ssl, + server_hostname=server_hostname, + https_compatible=True, + ) + await ssl_stream.do_handshake() + stream = ssl_stream + + # Send CONNECT request to the proxy and read response. + request = prepare_connect_request(proxy, ws_uri, user_agent_header) + await stream.send_all(request) + await read_connect_response(stream) + + except (trio.Cancelled, Exception): + await trio.aclose_forcefully(stream) + raise + + return stream diff --git a/src/websockets/trio/connection.py b/src/websockets/trio/connection.py new file mode 100644 index 000000000..61532316f --- /dev/null +++ b/src/websockets/trio/connection.py @@ -0,0 +1,1124 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import struct +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping +from types import TracebackType +from typing import Any, Literal, overload + +import trio +import trio.abc + +from ..asyncio.compatibility import ( + TimeoutError, + aiter, + anext, +) +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(trio.abc.AsyncResource): + """ + :mod:`trio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.trio.client.ClientConnection` or + :class:`~websockets.trio.server.ServerConnection`. 
+ + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.nursery = nursery + self.stream = stream + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + if isinstance(max_queue, int) or max_queue is None: + max_queue_high, max_queue_low = max_queue, None + else: + max_queue_high, max_queue_low = max_queue + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = trio.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler( + max_queue_high, + max_queue_low, + pause=self.recv_flow_control.acquire_nowait, + resume=self.recv_flow_control.release, + ) + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Lock preventing concurrent calls to send_all or send_eof. + self.send_lock = trio.Lock() + + # Protect sending fragmented messages. + self.send_in_progress: trio.Event | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[trio.Event, float, bool]] = {} + + self.latency: float = 0.0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.stream_closed: trio.Event = trio.Event() + + # Start recv_events only after all attributes are initialized. + self.nursery.start_soon(self.recv_events) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. 
+ + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getsockname() + else: # pragma: no cover + raise NotImplementedError(f"unsupported stream type: {stream}") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getpeername() + else: # pragma: no cover + raise NotImplementedError(f"unsupported stream type: {stream}") + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.aclose() + else: + await self.aclose(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. 
Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~trio.move_on_after` or :func:`~trio.fail_after`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`aclose`. 
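        A minimal sketch of consuming a fragmented message with
        :meth:`recv_streaming`, assuming the peer sends text fragments::

            fragments = []
            async for fragment in websocket.recv_streaming():
                fragments.append(fragment)
            message = "".join(fragments)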
+ + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + async def send( + self, + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable or asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`aclose`. 
Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`aclose` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`aclose` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.send_in_progress is not None: + await self.send_in_progress.wait() + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + # Fragmented message -- async iterator. 
+ + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + else: + raise TypeError("data must be str, bytes, iterable, or async iterable") + + async def aclose( + self, + code: CloseCode | int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`aclose` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`aclose` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.send_in_progress is not None: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + # Safety net: enforce the semantics of trio.abc.AsyncResource.aclose(). + except BaseException: # pragma: no cover + await trio.aclose_forcefully(self.stream) + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await self.stream_closed.wait() + + async def ping( + self, + data: DataLike | None = None, + ack_on_close: bool = False, + ) -> trio.Event: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. 
+ If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_received = await ws.ping() + # only if you want to wait for the corresponding pong + await pong_received.wait() + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = trio.Event() + ping_timestamp = trio.current_time() + self.pending_pings[data] = (pong_received, ping_timestamp, ack_on_close) + self.protocol.send_ping(data) + return pong_received + + async def pong(self, data: DataLike = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = trio.current_time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ( + pong_received, + ping_timestamp, + _ack_on_close, + ) in self.pending_pings.items(): + ping_ids.append(ping_id) + pong_received.set() + if ping_id == data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. 
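+
+        Only pings sent with ``ack_on_close=True`` have their event set; the
+        dictionary of pending pings is cleared in all cases.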
+ + """ + assert self.protocol.state is CLOSED + + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): + if ack_on_close: + pong_received.set() + + self.pending_pings.clear() + + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + with trio.move_on_after(self.ping_interval - self.latency): + await self.stream_closed.wait() + break + + try: + pong_received = await self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + with trio.move_on_after(self.ping_timeout) as cancel_scope: + await pong_received.wait() + if self.debug: + self.logger.debug("% received keepalive pong") + if cancel_scope.cancelled_caught: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.nursery.start_soon(self.keepalive) + + async def recv_events(self) -> None: + """ + Read incoming data from the stream and process events. + + Run this method in a task as long as the connection is alive. + + ``recv_events()`` exits immediately when ``self.stream`` is closed. + + """ + try: + while True: + try: + # If the assembler buffer is full, block until it drains. + async with self.recv_flow_control: + pass + data = await self.stream.receive_some() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the stream. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. + self.set_recv_exc(exc) + break + + if data == b"": + break + + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the stream. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the transport after recv() + # returns above but before send_data() calls send(). + self.set_recv_exc(exc) + break + + # If needed, set the close deadline based on the close timeout. + if self.protocol.close_expected(): + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. 
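+                # (process_event() only buffers data frames and acknowledges
+                # pongs; see above.)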
+ self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the stream doesn't work anymore. + + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + await self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + self.set_recv_exc(exc) + finally: + # This isn't expected to raise an exception. + await self.close_stream() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the stream and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN + # (or CONNECTING), self.close_deadline is still None. + assert self.close_deadline is None + if self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + # Write outgoing data to the socket with flow control. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("! error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. 
+ wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is not None: + with trio.move_on_at(self.close_deadline) as cancel_scope: + await self.stream_closed.wait() + if cancel_scope.cancelled_caught: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + else: + await self.stream_closed.wait() + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + await self.close_stream() + raise self.protocol.close_exc from original_exc + + async def send_data(self) -> None: + """ + Send outgoing data. + + """ + # Serialize calls to send_all(). + async with self.send_lock: + for data in self.protocol.data_to_send(): + if data: + await self.stream.send_all(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if isinstance(self.stream, trio.abc.HalfCloseableStream): + if self.debug: + self.logger.debug("x half-closing TCP connection") + try: + await self.stream.send_eof() + except Exception: # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + await self.stream.aclose() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + async def close_stream(self) -> None: + """ + Shutdown and close stream. Close message assembler. + + Calling close_stream() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on stream.recv() or on recv_messages.put(). + + """ + # Close the stream. + await self.stream.aclose() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() + + # Unblock coroutines waiting on self.stream_closed. 
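+        # (e.g. wait_closed(), send_context(), and the keepalive task).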
+ self.stream_closed.set() diff --git a/src/websockets/trio/messages.py b/src/websockets/trio/messages.py new file mode 100644 index 000000000..51f60e74d --- /dev/null +++ b/src/websockets/trio/messages.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import codecs +import math +from collections.abc import AsyncIterator +from typing import Any, Callable, Literal, overload + +import trio + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming frames. + self.send_frames: trio.MemorySendChannel[Frame] + self.recv_frames: trio.MemoryReceiveChannel[Frame] + self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf) + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is canceled. 
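+        # If get() is canceled while waiting for a continuation frame, the
+        # fragments received so far are put back into the queue so that a
+        # later call can still return the complete message.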
+ + try: + # Fetch the first frame. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Fetch subsequent frames for fragmented messages. + while not frame.fin: + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + # Put frames already received back into the queue + # so that future calls to get() can return them. + assert not self.send_frames._state.receive_tasks, ( + "no task should be waiting on receive()" + ) + assert not self.send_frames._state.data, "queue should be empty" + for frame in frames: + self.send_frames.send_nowait(frame) + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + # This converts frame.data to bytes when it's a bytearray. + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is canceled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # Yield the first frame. + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + self.get_in_progress = False + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + # Yield subsequent frames for fragmented messages. + while not frame.fin: + # We cannot handle trio.Cancelled because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. 
Future calls to + # get() or get_iter() will raise ConcurrencyError. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.send_frames.send_nowait(frame) + self.maybe_pause() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled. + if self.high is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "> high" to support high = 0. + if len(self.send_frames._state.data) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled. + if self.low is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "<= low" to support low = 0. + if len(self.send_frames._state.data) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.send_frames.close() diff --git a/src/websockets/trio/server.py b/src/websockets/trio/server.py new file mode 100644 index 000000000..390c87876 --- /dev/null +++ b/src/websockets/trio/server.py @@ -0,0 +1,634 @@ +from __future__ import annotations + +import functools +import http +import logging +import re +import ssl as ssl_module +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, Mapping + +import trio +import trio.abc + +from ..asyncio.server import basic_auth +from ..extensions.base import ServerExtensionFactory +from ..extensions.permessage_deflate import enable_server_permessage_deflate +from ..frames import CloseCode +from ..headers import validate_subprotocols +from ..http11 import SERVER, Request, Response +from ..protocol import CONNECTING, OPEN, Event +from ..server import ServerProtocol +from ..typing import LoggerLike, Origin, StatusLike, Subprotocol +from .connection import Connection +from .utils import race_events + + +__all__ = [ + "serve", + "ServerConnection", + "basic_auth", +] + + +class ServerConnection(Connection): + """ + :mod:`trio` implementation of a WebSocket server connection. + + :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for + receiving and sending messages. + + It supports asynchronous iteration to receive messages:: + + async for message in websocket: + await process(message) + + The iterator exits normally when the connection is closed with close code + 1000 (OK) or 1001 (going away) or without a close code. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is + closed with any other code. + + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. 
+ + Args: + nursery: Trio nursery. + stream: Trio stream connected to a WebSocket client. + protocol: Sans-I/O connection. + server: Server that manages this connection. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: ServerProtocol, + server: Server, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.protocol: ServerProtocol + super().__init__( + nursery, + stream, + protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + self.server = server + self.request_rcvd: trio.Event = trio.Event() + self.username: str # see basic_auth() + self.handler: Callable[[ServerConnection], Awaitable[None]] # see route() + self.handler_kwargs: Mapping[str, Any] # see route() + + def respond(self, status: StatusLike, text: str) -> Response: + """ + Create a plain text HTTP response. + + ``process_request`` and ``process_response`` may call this method to + return an HTTP response instead of performing the WebSocket opening + handshake. + + You can modify the response before returning it, for example by changing + HTTP headers. + + Args: + status: HTTP status code. + text: HTTP response body; it will be encoded to UTF-8. + + Returns: + HTTP response to send to the client. + + """ + return self.protocol.reject(status, text) + + async def handshake( + self, + process_request: ( + Callable[ + [ServerConnection, Request], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + process_response: ( + Callable[ + [ServerConnection, Request, Response], + Awaitable[Response | None] | Response | None, + ] + | None + ) = None, + server_header: str | None = SERVER, + ) -> None: + """ + Perform the opening handshake. 
+ + """ + await race_events(self.request_rcvd, self.stream_closed) + + if self.request is not None: + async with self.send_context(expected_state=CONNECTING): + response = None + + if process_request is not None: + try: + response = process_request(self, self.request) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is None: + if self.server.closing: + self.response = self.protocol.reject( + http.HTTPStatus.SERVICE_UNAVAILABLE, + "Server is shutting down.\n", + ) + else: + self.response = self.protocol.accept(self.request) + else: + assert isinstance(response, Response) # help mypy + self.response = response + + if server_header: + self.response.headers["Server"] = server_header + + response = None + + if process_response is not None: + try: + response = process_response(self, self.request, self.response) + if isinstance(response, Awaitable): + response = await response + except Exception as exc: + self.protocol.handshake_exc = exc + response = self.protocol.reject( + http.HTTPStatus.INTERNAL_SERVER_ERROR, + ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ), + ) + + if response is not None: + assert isinstance(response, Response) # help mypy + self.response = response + + self.protocol.send_response(self.response) + + # self.protocol.handshake_exc is set when the connection is lost before + # receiving a request, when the request cannot be parsed, or when the + # handshake fails, including when process_request or process_response + # raises an exception. + + # It isn't set when process_request or process_response sends an HTTP + # response that rejects the handshake. + + if self.protocol.handshake_exc is not None: + raise self.protocol.handshake_exc + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + """ + # First event - handshake request. + if self.request is None: + assert isinstance(event, Request) + self.request = event + self.request_rcvd.set() + # Later events - frames. + else: + super().process_event(event) + + +class Server(trio.abc.AsyncResource): + """ + WebSocket server returned by :func:`serve`. + + Args: + listeners: Trio listeners. + handler: Handler for one connection. Receives a Trio stream. + logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. + + """ + + def __init__( + self, + open_listeners: Callable[[], Awaitable[list[trio.SocketListener]]], + stream_handler: Callable[[trio.abc.Stream, Server], Awaitable[None]], + logger: LoggerLike | None = None, + ) -> None: + self.open_listeners = open_listeners + self.stream_handler = stream_handler + if logger is None: + logger = logging.getLogger("websockets.server") + self.logger = logger + + self.listeners: list[trio.SocketListener] + """Trio listeners.""" + + self.waiters: dict[ServerConnection, trio.Event] = {} + self.closing = False + + @property + def connections(self) -> set[ServerConnection]: + """ + Set of active connections. + + This property contains all connections that completed the opening + handshake successfully and didn't start the closing handshake yet. + + .. It can be useful in combination with :func:`~broadcast`. 
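+
+        For example, to notify every client (a sketch; handling of
+        :exc:`~websockets.exceptions.ConnectionClosed` is omitted)::
+
+            for connection in server.connections:
+                await connection.send("server restarting soon")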
+ + """ + return { + connection + for connection in self.waiters + if connection.protocol.state is OPEN + } + + async def serve_forever( + self, + task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED, + ) -> None: + self.listeners = await self.open_listeners() # used in tests + # Running handlers in a dedicated nursery makes it possible to close + # listeners while handlers finish running. The nursery for listeners + # is created in trio.serve_listeners(). + async with trio.open_nursery() as self.handler_nursery: + # Wrap trio.serve_listeners() in another nursery to return the + # Server object in task_status instead of a list of listeners. + async with trio.open_nursery() as self.serve_nursery: + await self.serve_nursery.start( + functools.partial( + trio.serve_listeners, + functools.partial(self.stream_handler, server=self), # type: ignore + self.listeners, + handler_nursery=self.handler_nursery, + ) + ) + task_status.started(self) + + # Shutting down the server cleanly when serve_forever() is canceled would be + # the most idiomatic in Trio. However, that would require shielding too many + # asynchronous operations, including the TLS & WebSocket opening handshakes. + + async def aclose( + self, + close_connections: bool = True, + code: CloseCode | int = CloseCode.GOING_AWAY, + reason: str = "", + ) -> None: + """ + Close the server. + + * Close the TCP listeners. + * When ``close_connections`` is :obj:`True`, which is the default, + close existing connections. Specifically: + + * Reject opening WebSocket connections with an HTTP 503 (service + unavailable) error. This happens when the server accepted the TCP + connection but didn't complete the opening handshake before closing. + * Close open WebSocket connections with close code 1001 (going away). + ``code`` and ``reason`` can be customized, for example to use code + 1012 (service restart). + + * Wait until all connection handlers have returned. + + :meth:`aclose` is idempotent. + + """ + self.logger.info("server closing") + + # Stop accepting new connections. + self.serve_nursery.cancel_scope.cancel() + + # Reject OPENING connections with HTTP 503 -- see handshake(). + self.closing = True + + # Close OPEN connections. + if close_connections: + for connection in self.waiters: + if connection.protocol.state is not OPEN: # pragma: no cover + continue + self.handler_nursery.start_soon(connection.aclose, code, reason) + + # Wait until all connection handlers have returned. 
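+        # Each handler removes its connection from self.waiters and sets the
+        # corresponding event when it returns, so this loop eventually drains.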
+        while self.waiters:
+            await next(iter(self.waiters.values())).wait()
+
+        self.logger.info("server closed")
+
+
+async def serve(
+    handler: Callable[[ServerConnection], Awaitable[None]],
+    port: int | None = None,
+    *,
+    # TCP/TLS
+    host: str | bytes | None = None,
+    backlog: int | None = None,
+    listeners: list[trio.SocketListener] | None = None,
+    ssl: ssl_module.SSLContext | None = None,
+    # WebSocket
+    origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
+    extensions: Sequence[ServerExtensionFactory] | None = None,
+    subprotocols: Sequence[Subprotocol] | None = None,
+    select_subprotocol: (
+        Callable[
+            [ServerConnection, Sequence[Subprotocol]],
+            Subprotocol | None,
+        ]
+        | None
+    ) = None,
+    compression: str | None = "deflate",
+    # HTTP
+    process_request: (
+        Callable[
+            [ServerConnection, Request],
+            Response | None,
+        ]
+        | None
+    ) = None,
+    process_response: (
+        Callable[
+            [ServerConnection, Request, Response],
+            Response | None,
+        ]
+        | None
+    ) = None,
+    server_header: str | None = SERVER,
+    # Timeouts
+    open_timeout: float | None = 10,
+    ping_interval: float | None = 20,
+    ping_timeout: float | None = 20,
+    close_timeout: float | None = 10,
+    # Limits
+    max_size: int | None | tuple[int | None, int | None] = 2**20,
+    max_queue: int | None | tuple[int | None, int | None] = 16,
+    # Logging
+    logger: LoggerLike | None = None,
+    # Escape hatch for advanced customization
+    create_connection: type[ServerConnection] | None = None,
+    # Trio
+    task_status: trio.TaskStatus[Server] = trio.TASK_STATUS_IGNORED,
+) -> None:
+    """
+    Create a WebSocket server listening on ``port``.
+
+    Whenever a client connects, the server creates a :class:`ServerConnection`,
+    performs the opening handshake, and delegates to the ``handler`` coroutine.
+
+    The handler receives the :class:`ServerConnection` instance, which you can
+    use to send and receive messages.
+
+    Once the handler completes, either normally or with an exception, the server
+    performs the closing handshake and closes the connection.
+
+    When using :func:`serve` with :meth:`nursery.start <trio.Nursery.start>`,
+    you get back a :class:`Server` object. Call its :meth:`~Server.aclose`
+    method to stop the server gracefully::
+
+        import trio
+
+        from websockets.trio.server import serve
+
+        async def handler(websocket):
+            ...
+
+        # set this event to exit the server
+        stop = trio.Event()
+
+        async with trio.open_nursery() as nursery:
+            server = await nursery.start(serve, handler, port)
+            await stop.wait()
+            await server.aclose()
+
+    Args:
+        handler: Connection handler. It receives the WebSocket connection,
+            which is a :class:`ServerConnection`, as an argument.
+        port: TCP port the server listens on.
+            See :func:`~trio.open_tcp_listeners` for details.
+        host: Network interfaces the server binds to.
+            See :func:`~trio.open_tcp_listeners` for details.
+        backlog: Listen backlog. See :func:`~trio.open_tcp_listeners` for
+            details.
+        listeners: Preexisting TCP listeners. ``listeners`` replaces ``port``,
+            ``host``, and ``backlog``. See :func:`trio.serve_listeners` for
+            details.
+        ssl: Configuration for enabling TLS on the connection.
+        origins: Acceptable values of the ``Origin`` header, for defending
+            against Cross-Site WebSocket Hijacking attacks. Values can be
+            :class:`str` to test for an exact match or regular expressions
+            compiled by :func:`re.compile` to test against a pattern. Include
+            :obj:`None` in the list if the lack of an origin is acceptable.
+        extensions: List of supported extensions, in the order in which they
+            should be negotiated and run.
+        subprotocols: List of supported subprotocols, in order of decreasing
+            preference.
+        select_subprotocol: Callback for selecting a subprotocol among
+            those supported by the client and the server. It receives a
+            :class:`ServerConnection` (not a
+            :class:`~websockets.server.ServerProtocol`!) instance and a list of
+            subprotocols offered by the client. Other than the first argument,
+            it has the same behavior as the
+            :meth:`ServerProtocol.select_subprotocol
+            <websockets.server.ServerProtocol.select_subprotocol>` method.
+        compression: The "permessage-deflate" extension is enabled by default.
+            Set ``compression`` to :obj:`None` to disable it. See the
+            :doc:`compression guide <../../topics/compression>` for details.
+        process_request: Intercept the request during the opening handshake.
+            Return an HTTP response to force the response. Return :obj:`None` to
+            continue normally. When you force an HTTP 101 (Switching Protocols)
+            response, the handshake is successful. Else, the connection is
+            aborted.
+        process_response: Intercept the response during the opening handshake.
+            Modify the response or return a new HTTP response to force the
+            response. Return :obj:`None` to continue normally. When you force an
+            HTTP 101 (Switching Protocols) response, the handshake is
+            successful. Else, the connection is aborted.
+        server_header: Value of the ``Server`` response header.
+            It defaults to ``"Python/x.y.z websockets/X.Y"``. Setting it to
+            :obj:`None` removes the header.
+        open_timeout: Timeout for opening connections in seconds.
+            :obj:`None` disables the timeout.
+        ping_interval: Interval between keepalive pings in seconds.
+            :obj:`None` disables keepalive.
+        ping_timeout: Timeout for keepalive pings in seconds.
+            :obj:`None` disables timeouts.
+        close_timeout: Timeout for closing connections in seconds.
+            :obj:`None` disables the timeout.
+        max_size: Maximum size of incoming messages in bytes.
+            :obj:`None` disables the limit. You may pass a ``(max_message_size,
+            max_fragment_size)`` tuple to set different limits for messages and
+            fragments when you expect long messages sent in short fragments.
+        max_queue: High-water mark of the buffer where frames are received.
+            It defaults to 16 frames. The low-water mark defaults to ``max_queue
+            // 4``. You may pass a ``(high, low)`` tuple to set the high-water
+            and low-water marks. If you want to disable flow control entirely,
+            you may set it to ``None``, although that's a bad idea.
+        logger: Logger for this server.
+            It defaults to ``logging.getLogger("websockets.server")``. See the
+            :doc:`logging guide <../../topics/logging>` for details.
+        create_connection: Factory for the :class:`ServerConnection` managing
+            the connection. Set it to a wrapper or a subclass to customize
+            connection handling.
+        task_status: For compatibility with :meth:`nursery.start
+            <trio.Nursery.start>`.
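+
+    For example, to bind the listeners yourself and hand them to :func:`serve`
+    (a sketch reusing ``nursery`` and ``handler`` from the example above;
+    ``functools.partial`` is needed because :meth:`nursery.start
+    <trio.Nursery.start>` doesn't forward keyword arguments)::
+
+        listeners = await trio.open_tcp_listeners(8765)
+        server = await nursery.start(
+            functools.partial(serve, handler, listeners=listeners)
+        )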
+ + """ + + # Process parameters + + if subprotocols is not None: + validate_subprotocols(subprotocols) + + if compression == "deflate": + extensions = enable_server_permessage_deflate(extensions) + elif compression is not None: + raise ValueError(f"unsupported compression: {compression}") + + if create_connection is None: + create_connection = ServerConnection + + # Create listeners + + if listeners is None: + if port is None: + raise ValueError("port is required when listeners is not provided") + + async def open_listeners() -> list[trio.SocketListener]: + return await trio.open_tcp_listeners(port, host=host, backlog=backlog) + else: + if port is not None: + raise ValueError("port is incompatible with listeners") + if host is not None: + raise ValueError("host is incompatible with listeners") + if backlog is not None: + raise ValueError("backlog is incompatible with listeners") + + async def open_listeners() -> list[trio.SocketListener]: + return listeners + + async def stream_handler(stream: trio.abc.Stream, server: Server) -> None: + async with trio.open_nursery() as nursery: + try: + # Apply open_timeout to the TLS and WebSocket handshake. + with ( + trio.CancelScope() + if open_timeout is None + else trio.move_on_after(open_timeout) + ): + # Enable TLS. + if ssl is not None: + # Wrap with SSLStream here rather than with TLSListener + # in order to include the TLS handshake within open_timeout. + stream = trio.SSLStream( + stream, + ssl, + server_side=True, + https_compatible=True, + ) + assert isinstance(stream, trio.SSLStream) # help mypy + try: + await stream.do_handshake() + except trio.BrokenResourceError: + return + + # Create a closure to give select_subprotocol access to connection. + protocol_select_subprotocol: ( + Callable[ + [ServerProtocol, Sequence[Subprotocol]], + Subprotocol | None, + ] + | None + ) = None + if select_subprotocol is not None: + + def protocol_select_subprotocol( + protocol: ServerProtocol, + subprotocols: Sequence[Subprotocol], + ) -> Subprotocol | None: + # mypy doesn't know that select_subprotocol is immutable. + assert select_subprotocol is not None + # Ensure this function is only used in the intended context. + assert protocol is connection.protocol + return select_subprotocol(connection, subprotocols) + + # Initialize WebSocket protocol. + protocol = ServerProtocol( + origins=origins, + extensions=extensions, + subprotocols=subprotocols, + select_subprotocol=protocol_select_subprotocol, + max_size=max_size, + logger=logger, + ) + + # Initialize WebSocket connection. + assert create_connection is not None # help mypy + connection = create_connection( + nursery, + stream, + protocol, + server, + ping_interval=ping_interval, + ping_timeout=ping_timeout, + close_timeout=close_timeout, + max_queue=max_queue, + ) + + try: + await connection.handshake( + process_request, + process_response, + server_header, + ) + except trio.Cancelled: + # The nursery running this coroutine was canceled. + # The next checkpoint raises trio.Cancelled. + # aclose_forcefully() never returns. + await trio.aclose_forcefully(stream) + raise AssertionError("nursery should be canceled") + except Exception: + connection.logger.error( + "opening handshake failed", exc_info=True + ) + await trio.aclose_forcefully(stream) + return + + if connection.protocol.state is not OPEN: + # process_request or process_response rejected the handshake. 
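+                    # The rejection response, if any, was already sent by
+                    # handshake(); just release the stream.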
+ await connection.close_stream() + return + + try: + server.waiters[connection] = trio.Event() + connection.start_keepalive() + await handler(connection) + except Exception: + connection.logger.error("connection handler failed", exc_info=True) + await connection.aclose(CloseCode.INTERNAL_ERROR) + else: + await connection.aclose() + finally: + server.waiters.pop(connection).set() + + except Exception: # pragma: no cover + # Don't leak connections on unexpected errors. + await trio.aclose_forcefully(stream) + + server = Server(open_listeners, stream_handler, logger) + await server.serve_forever(task_status=task_status) diff --git a/src/websockets/trio/utils.py b/src/websockets/trio/utils.py new file mode 100644 index 000000000..8f3bdd822 --- /dev/null +++ b/src/websockets/trio/utils.py @@ -0,0 +1,42 @@ +import sys + +import trio + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +__all__ = ["race_events"] + + +# Based on https://trio.readthedocs.io/en/stable/reference-core.html#custom-supervisors + + +async def jockey(event: trio.Event, cancel_scope: trio.CancelScope) -> None: + await event.wait() + cancel_scope.cancel() + + +async def race_events(*events: trio.Event) -> None: + """ + Wait for any of the given events to be set. + + Args: + *events: The events to wait for. + + """ + if not events: + raise ValueError("no events provided") + + try: + async with trio.open_nursery() as nursery: + for event in events: + nursery.start_soon(jockey, event, nursery.cancel_scope) + except BaseExceptionGroup as exc: + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise AssertionError( + "race_events should be canceled; please file a bug report" + ) from exc diff --git a/src/websockets/uri.py b/src/websockets/uri.py index b925b99b5..f85e16810 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,9 +2,8 @@ import dataclasses import urllib.parse -import urllib.request -from .exceptions import InvalidProxy, InvalidURI +from .exceptions import InvalidURI __all__ = ["parse_uri", "WebSocketURI"] @@ -106,120 +105,3 @@ def parse_uri(uri: str) -> WebSocketURI: password = urllib.parse.quote(password, safe=DELIMS) return WebSocketURI(secure, host, port, path, query, username, password) - - -@dataclasses.dataclass -class Proxy: - """ - Proxy. - - Attributes: - scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, - ``"https"``, or ``"http"``. - host: Normalized to lower case. - port: Always set even if it's the default. - username: Available when the proxy address contains `User Information`_. - password: Available when the proxy address contains `User Information`_. - - .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 - - """ - - scheme: str - host: str - port: int - username: str | None = None - password: str | None = None - - @property - def user_info(self) -> tuple[str, str] | None: - if self.username is None: - return None - assert self.password is not None - return (self.username, self.password) - - -def parse_proxy(proxy: str) -> Proxy: - """ - Parse and validate a proxy. - - Args: - proxy: proxy. - - Returns: - Parsed proxy. - - Raises: - InvalidProxy: If ``proxy`` isn't a valid proxy. 
- - """ - parsed = urllib.parse.urlparse(proxy) - if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: - raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") - if parsed.hostname is None: - raise InvalidProxy(proxy, "hostname isn't provided") - if parsed.path not in ["", "/"]: - raise InvalidProxy(proxy, "path is meaningless") - if parsed.query != "": - raise InvalidProxy(proxy, "query is meaningless") - if parsed.fragment != "": - raise InvalidProxy(proxy, "fragment is meaningless") - - scheme = parsed.scheme - host = parsed.hostname - port = parsed.port or (443 if parsed.scheme == "https" else 80) - username = parsed.username - password = parsed.password - # urllib.parse.urlparse accepts URLs with a username but without a - # password. This doesn't make sense for HTTP Basic Auth credentials. - if username is not None and password is None: - raise InvalidProxy(proxy, "username provided without password") - - try: - proxy.encode("ascii") - except UnicodeEncodeError: - # Input contains non-ASCII characters. - # It must be an IRI. Convert it to a URI. - host = host.encode("idna").decode() - if username is not None: - assert password is not None - username = urllib.parse.quote(username, safe=DELIMS) - password = urllib.parse.quote(password, safe=DELIMS) - - return Proxy(scheme, host, port, username, password) - - -def get_proxy(uri: WebSocketURI) -> str | None: - """ - Return the proxy to use for connecting to the given WebSocket URI, if any. - - """ - if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): - return None - - # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if - # available, else favor the proxy for HTTPS connections over the proxy for - # HTTP connections. - - # The priority of a proxy for WebSocket connections is unspecified. We give - # it the highest priority. This makes it easy to configure a specific proxy - # for websockets. - - # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or - # as {"https": "socks5h://host:port"} depending on whether they're declared - # in the operating system or in environment variables. - - proxies = urllib.request.getproxies() - if uri.secure: - schemes = ["wss", "socks", "https"] - else: - schemes = ["ws", "socks", "https", "http"] - - for scheme in schemes: - proxy = proxies.get(scheme) - if proxy is not None: - if scheme == "socks" and proxy.startswith("http://"): - proxy = "socks5h://" + proxy[7:] - return proxy - else: - return None diff --git a/tests/__init__.py b/tests/__init__.py index bb1866f2d..83b10efb2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ import logging import os +import tracemalloc format = "%(asctime)s %(levelname)s %(name)s %(message)s" @@ -12,3 +13,7 @@ level = logging.CRITICAL logging.basicConfig(format=format, level=level) + +if bool(os.environ.get("WEBSOCKETS_TRACE")): # pragma: no cover + # Trace allocations to debug resource warnings. + tracemalloc.start() diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py index ad1c121bf..5cd673d97 100644 --- a/tests/asyncio/connection.py +++ b/tests/asyncio/connection.py @@ -21,7 +21,7 @@ def delay_frames_sent(self, delay): """ Add a delay before sending frames. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write is None @@ -36,7 +36,7 @@ def delay_eof_sent(self, delay): """ Add a delay before sending EOF. 
- This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write_eof is None @@ -83,9 +83,9 @@ class InterceptingTransport: This is coupled to the implementation, which relies on these two methods. - Since ``write()`` and ``write_eof()`` are not coroutines, this effect is - achieved by scheduling writes at a later time, after the methods return. - This can easily result in out-of-order writes, which is unrealistic. + Since ``write()`` and ``write_eof()`` are synchronous, we can only schedule + writes at a later time, after they return. This is unrealistic and can lead + to out-of-order writes if tests aren't written carefully. """ @@ -101,15 +101,15 @@ def __getattr__(self, name): return getattr(self.transport, name) def write(self, data): - if not self.drop_write: - if self.delay_write is not None: - self.loop.call_later(self.delay_write, self.transport.write, data) - else: - self.transport.write(data) + if self.delay_write is not None: + assert not self.drop_write + self.loop.call_later(self.delay_write, self.transport.write, data) + elif not self.drop_write: + self.transport.write(data) def write_eof(self): - if not self.drop_write_eof: - if self.delay_write_eof is not None: - self.loop.call_later(self.delay_write_eof, self.transport.write_eof) - else: - self.transport.write_eof() + if self.delay_write_eof is not None: + assert not self.drop_write_eof + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + elif not self.drop_write_eof: + self.transport.write_eof() diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index a83074ae8..eff026230 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -75,7 +75,7 @@ async def test_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -341,7 +341,7 @@ def redirect(connection, request): async with serve(*args, process_request=redirect) as server: with socket.create_connection(get_host_port(server)) as sock: with self.assertRaises(ValueError) as raised: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/redirect", sock=sock): self.fail("did not raise") @@ -446,9 +446,11 @@ async def test_junk_handshake(self): """Client closes the connection when receiving non-HTTP response from server.""" async def junk(reader, writer): - await asyncio.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + await asyncio.sleep(MS) writer.write(b"220 smtp.invalid ESMTP Postfix\r\n") - await reader.read(4096) # wait for the client to close the connection + # Wait for the client to close the connection. + await reader.read(4096) writer.close() server = await asyncio.start_server(junk, "localhost", 0) @@ -652,7 +654,7 @@ async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" async with serve(*args) as server: with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. 
+ # Use a non-existing domain to ensure we connect via sock. async with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -1000,3 +1002,16 @@ async def test_unsupported_compression(self): str(raised.exception), "unsupported compression: False", ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with serve(*args) as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 6cad971c7..e31cac5b5 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import itertools import logging import socket import sys @@ -19,9 +20,8 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, alist from .connection import InterceptingConnection -from .utils import alist # Connection implements symmetrical behavior between clients and servers. @@ -33,13 +33,13 @@ class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): REMOTE = SERVER async def asyncSetUp(self): - loop = asyncio.get_running_loop() + self.loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() - self.transport, self.connection = await loop.create_connection( + self.transport, self.connection = await self.loop.create_connection( lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), sock=socket_, ) - self.remote_transport, self.remote_connection = await loop.create_connection( + _remote_transport, self.remote_connection = await self.loop.create_connection( lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), sock=remote_socket, ) @@ -50,27 +50,25 @@ async def asyncTearDown(self): # Test helpers built upon RecordingProtocol and InterceptingConnection. - async def assertFrameSent(self, frame): - """Check that a single frame was sent.""" - # Let the remote side process messages. + async def wait_for_remote_side(self): + """Wait for the remote side to process messages.""" # Two runs of the event loop are required for answering pings. await asyncio.sleep(0) await asyncio.sleep(0) + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) async def assertFramesSent(self, frames): """Check that several frames were sent.""" - # Let the remote side process messages. - # Two runs of the event loop are required for answering pings. - await asyncio.sleep(0) - await asyncio.sleep(0) + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) async def assertNoFrameSent(self): """Check that no frame was sent.""" - # Run the event loop twice for consistency with assertFrameSent. 
- await asyncio.sleep(0) - await asyncio.sleep(0) + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.asynccontextmanager @@ -78,28 +76,28 @@ async def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() # Test __aenter__ and __aexit__. @@ -114,8 +112,8 @@ async def test_aexit(self): await self.assertNoFrameSent() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - async def test_exit_with_exception(self): - """__exit__ with an exception closes the connection with code 1011.""" + async def test_aexit_with_exception(self): + """__aexit__ with an exception closes the connection with code 1011.""" with self.assertRaises(RuntimeError): async with self.connection: raise RuntimeError @@ -125,41 +123,46 @@ async def test_exit_with_exception(self): async def test_aiter_text(self): """__aiter__ yields text messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") async def test_aiter_binary(self): """__aiter__ yields binary messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_mixed(self): """__aiter__ yields a mix of text and binary messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + 
self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_connection_closed_ok(self): """__aiter__ terminates after a normal closure.""" - aiterator = aiter(self.connection) - await self.remote_connection.close() - with self.assertRaises(StopAsyncIteration): - await anext(aiterator) + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) async def test_aiter_connection_closed_error(self): """__aiter__ raises ConnectionClosedError after an error.""" - aiterator = aiter(self.connection) - await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) - with self.assertRaises(ConnectionClosedError): - await anext(aiterator) + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) # Test recv. @@ -245,10 +248,9 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be canceled before receiving a frame.""" + """recv can be canceled before receiving a message.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task - recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task @@ -257,25 +259,25 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be canceled after receiving a frame.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(0) # let the event loop start recv_task - - gate = asyncio.get_running_loop().create_future() + """recv can be canceled while receiving a fragmented message.""" + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task + gate.set() + # Running recv again receives the complete message. - gate.set_result(None) self.assertEqual(await self.connection.recv(), "⏳⌛️") # Test recv_streaming. 
@@ -404,28 +406,31 @@ async def test_recv_streaming_cancellation_before_receiving(self): async def test_recv_streaming_cancellation_while_receiving(self): """recv_streaming cannot be canceled after receiving a frame.""" - recv_streaming_task = asyncio.create_task( - alist(self.connection.recv_streaming()) - ) - await asyncio.sleep(0) # let the event loop start recv_streaming_task - - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + await asyncio.sleep(0) # experimentally, two runs of the event loop + await asyncio.sleep(0) # are needed to receive the first fragment recv_streaming_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_streaming_task - gate.set_result(None) + gate.set() + # Running recv_streaming again fails. with self.assertRaises(ConcurrencyError): - await alist(self.connection.recv_streaming()) + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") # Test send. @@ -553,23 +558,31 @@ async def test_send_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.send("😀") - async def test_send_while_send_blocked(self): + async def test_send_during_send(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed - # from send() in the case when message is an Iterable. - self.connection.pause_writing() - asyncio.create_task(self.connection.send(["⏳", "⌛️"])) - await asyncio.sleep(MS) + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an AsyncIterable. + gate = asyncio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() - self.connection.resume_writing() - await asyncio.sleep(MS) + gate.set() + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -578,28 +591,26 @@ async def test_send_while_send_blocked(self): ] ) - async def test_send_while_send_async_blocked(self): - """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed - # from send() in the case when message is an AsyncIterable. + async def test_send_while_send_blocked(self): + """send waits for a blocked call to send to complete.""" + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an Iterable. 
self.connection.pause_writing() - async def fragments(): - yield "⏳" - yield "⌛️" - - asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() self.connection.resume_writing() - await asyncio.sleep(MS) + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -608,29 +619,30 @@ async def fragments(): ] ) - async def test_send_during_send_async(self): - """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + async def test_send_while_send_async_blocked(self): + """send waits for a blocked call to send to complete.""" + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. - gate = asyncio.get_running_loop().create_future() + self.connection.pause_writing() async def fragments(): yield "⏳" - await gate yield "⌛️" asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() - gate.set_result(None) - await asyncio.sleep(MS) + self.connection.resume_writing() + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -708,9 +720,15 @@ async def test_close_explicit_code_reason(self): await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) async def test_close_waits_for_close_frame(self): - """close waits for a close frame (then EOF) before returning.""" + """close waits for a close frame then EOF before returning.""" + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -724,8 +742,14 @@ async def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -735,11 +759,17 @@ async def test_close_waits_for_connection_closed(self): self.assertIsNone(exc.__cause__) async def test_close_no_timeout_waits_for_close_frame(self): - """close without timeout waits for a close frame (then EOF) 
before returning.""" + """close without timeout waits for a close frame then EOF before returning.""" self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -755,8 +785,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -767,8 +803,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): async def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = self.loop.time() async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() @@ -782,8 +824,14 @@ async def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.drop_eof_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -798,13 +846,9 @@ async def test_close_preserves_queued_messages(self): await self.connection.close() self.assertEqual(await self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close() @@ -815,11 +859,15 @@ async def test_close_idempotency(self): async def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(MS) - await self.connection.close() + + async def closer(): + await asyncio.sleep(MS) + await self.connection.close() + + asyncio.create_task(closer()) + with self.assertRaises(ConnectionClosedOK) as raised: - await recv_task + await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") @@ -827,23 +875,26 @@ async def test_close_during_recv(self): async def test_close_during_send(self): """close fails the connection when called concurrently with send.""" - gate = asyncio.get_running_loop().create_future() + close_gate = asyncio.Event() + exit_gate = 
asyncio.Event() + + async def closer(): + await close_gate.wait() + await self.connection.close() + exit_gate.set() async def fragments(): yield "⏳" - await gate + close_gate.set() + await exit_gate.wait() yield "⌛️" - send_task = asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) - - asyncio.create_task(self.connection.close()) - await asyncio.sleep(MS) + asyncio.create_task(closer()) - gate.set_result(None) - - with self.assertRaises(ConnectionClosedError) as raised: - await send_task + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) exc = raised.exception self.assertEqual( @@ -865,9 +916,10 @@ async def test_wait_closed(self): # Test ping. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) await self.connection.ping() getrandbits.assert_called_once_with(32) await self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -885,54 +937,54 @@ async def test_ping_explicit_binary(self): async def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("this") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_canceled_ping(self): """ping is acknowledged by a pong with the same payload after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received.cancel() await self.remote_connection.pong("this") with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("that") with self.assertRaises(TimeoutError): async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.connection.ping("that") await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_canceled_ping(self): """ping is acknowledged by a pong for a later ping after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter_2 = await self.connection.ping("that") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received_2 = await self.connection.ping("that") + pong_received.cancel() await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter_2 + await 
pong_received_2 with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("idem") + pong_received = await self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") @@ -943,7 +995,7 @@ async def test_ping_duplicate_payload(self): await self.remote_connection.pong("idem") async with asyncio_timeout(MS): - await pong_waiter + await pong_received await self.connection.ping("idem") # doesn't raise an exception @@ -976,9 +1028,10 @@ async def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_task) @@ -997,35 +1050,35 @@ async def test_disable_keepalive(self): self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_task) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. + await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection is closed. - await asyncio.sleep(2 * MS) + await asyncio.sleep(3 * MS) # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. # 4.x ms: a pong frame is dropped. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for 1 ms. + await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection remains open. - await asyncio.sleep(2 * MS) + await asyncio.sleep(3 * MS) # 7 ms: check that the connection is still open. 
self.assertEqual(self.connection.state, State.OPEN) @@ -1033,6 +1086,7 @@ async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertFalse(self.connection.keepalive_task.done()) await asyncio.sleep(MS) await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1040,13 +1094,12 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = MS - self.connection.ping_timeout = 3 * MS + self.connection.ping_timeout = 4 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 1 ms: keepalive() sends a ping frame. # 1.x ms: a pong frame is dropped. - await asyncio.sleep(MS) - # Exiting the context manager sleeps for 1 ms. + await asyncio.sleep(2 * MS) # 2 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1061,9 +1114,9 @@ async def test_keepalive_reports_errors(self): await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. - pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + pong_received = next(iter(self.connection.pending_pings.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: - pong_waiter.set_exception(Exception("BOOM")) + pong_received.set_exception(Exception("BOOM")) await asyncio.sleep(0) self.assertEqual( [record.getMessage() for record in logs.records], @@ -1078,20 +1131,28 @@ async def test_keepalive_reports_errors(self): async def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + connection = Connection( + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=4) - transport = Mock() - connection.connection_made(transport) + connection = Connection( + Protocol(self.LOCAL), + max_queue=4, + ) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, 4) async def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=None) - transport = Mock() + connection = Connection( + Protocol(self.LOCAL), + max_queue=None, + ) + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.low, None) @@ -1102,7 +1163,7 @@ async def test_max_queue_tuple(self): Protocol(self.LOCAL), max_queue=(4, 2), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) self.assertEqual(connection.recv_messages.low, 2) @@ -1113,7 +1174,7 @@ async def test_write_limit(self): Protocol(self.LOCAL), write_limit=4096, ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) @@ -1123,7 +1184,7 
@@ async def test_write_limits(self): Protocol(self.LOCAL), write_limit=(4096, 2048), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) @@ -1137,13 +1198,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -1185,9 +1246,7 @@ async def test_writing_in_data_received_fails(self): # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + self.assertIsInstance(raised.exception.__cause__, RuntimeError) async def test_writing_in_send_context_fails(self): """Error when sending outgoing frame is correctly reported.""" @@ -1198,29 +1257,18 @@ async def test_writing_in_send_context_fails(self): # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.pong() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + self.assertIsInstance(raised.exception.__cause__, RuntimeError) # Test safety nets — catching all exceptions in case of bugs. - # Inject a fault in a random call in data_received(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" - # Receive a message to trigger the fault. await self.remote_connection.send("😀") - with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) - - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" @@ -1228,10 +1276,7 @@ async def test_unexpected_failure_in_send_context(self, send_text): # The connection closed exception reports the injected fault. 
with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") - - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Test broadcast. @@ -1302,11 +1347,11 @@ async def test_broadcast_skips_closing_connection(self): async def test_broadcast_skips_connection_with_send_blocked(self): """broadcast logs a warning when a connection is blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) @@ -1320,7 +1365,7 @@ async def fragments(): ["skipped broadcast: sending a fragmented message"], ) - gate.set_result(None) + gate.set() await send_task @unittest.skipIf( @@ -1329,11 +1374,11 @@ async def fragments(): ) async def test_broadcast_reports_connection_with_send_blocked(self): """broadcast raises exceptions for connections blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) @@ -1347,7 +1392,7 @@ async def fragments(): self.assertEqual(str(exc), "sending a fragmented message") self.assertIsInstance(exc, ConcurrencyError) - gate.set_result(None) + gate.set() await send_task async def test_broadcast_skips_connection_failing_to_send(self): diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index a90788d02..c862090a3 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import unittest import unittest.mock @@ -8,7 +9,7 @@ from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame -from .utils import alist +from ..utils import alist class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): @@ -32,7 +33,7 @@ async def test_put_then_get(self): async def test_get_then_put(self): """get returns an item when it is put.""" getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start + await asyncio.sleep(0) # let the event loop start getter_task self.queue.put(42) item = await getter_task self.assertEqual(item, 42) @@ -46,7 +47,7 @@ async def test_reset(self): async def test_abort(self): """abort throws an exception in get.""" getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start + await asyncio.sleep(0) # let the event loop start getter_task self.queue.abort() with self.assertRaises(EOFError): await getter_task @@ -58,7 +59,7 @@ async def asyncSetUp(self): self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) - # Test get + # Test get. 
async def test_get_text_message_already_received(self): """get returns a text message that is already received.""" @@ -107,6 +108,7 @@ async def test_get_fragmented_binary_message_already_received(self): async def test_get_fragmented_text_message_not_received_yet(self): """get reassembles a fragmented text message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) @@ -116,6 +118,7 @@ async def test_get_fragmented_text_message_not_received_yet(self): async def test_get_fragmented_binary_message_not_received_yet(self): """get reassembles a fragmented binary message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -126,6 +129,7 @@ async def test_get_fragmented_text_message_being_received(self): """get reassembles a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await getter_task @@ -135,6 +139,7 @@ async def test_get_fragmented_binary_message_being_received(self): """get reassembles a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = await getter_task @@ -161,11 +166,9 @@ async def test_get_resumes_reading(self): # queue is above the low-water mark await self.assembler.get() self.resume.assert_not_called() - # queue is at the low-water mark await self.assembler.get() self.resume.assert_called_once_with() - # queue is below the low-water mark await self.assembler.get() self.resume.assert_called_once_with() @@ -180,7 +183,6 @@ async def test_get_does_not_resume_reading(self): await self.assembler.get() await self.assembler.get() await self.assembler.get() - self.resume.assert_not_called() async def test_cancel_get_before_first_frame(self): @@ -192,7 +194,6 @@ async def test_cancel_get_before_first_frame(self): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - message = await self.assembler.get() self.assertEqual(message, "café") @@ -208,11 +209,10 @@ async def test_cancel_get_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - message = await self.assembler.get() self.assertEqual(message, "café") - # Test get_iter + # Test get_iter. 
async def test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" @@ -261,42 +261,46 @@ async def test_get_iter_fragmented_binary_message_already_received(self): async def test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = aiter(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(await anext(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(await anext(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(await anext(iterator), "é") + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = aiter(self.assembler.get_iter()) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assertEqual(await anext(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(await anext(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(await anext(iterator), b"a") + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") async def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = aiter(self.assembler.get_iter()) - self.assertEqual(await anext(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(await anext(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(await anext(iterator), "é") + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = aiter(self.assembler.get_iter()) - self.assertEqual(await anext(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(await anext(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(await anext(iterator), b"a") + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") async def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -321,18 
+325,16 @@ async def test_get_iter_resumes_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) - - # queue is above the low-water mark - await anext(iterator) - self.resume.assert_not_called() - - # queue is at the low-water mark - await anext(iterator) - self.resume.assert_called_once_with() - - # queue is below the low-water mark - await anext(iterator) - self.resume.assert_called_once_with() + async with contextlib.aclosing(iterator): + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() async def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" @@ -342,11 +344,11 @@ async def test_get_iter_does_not_resume_reading(self): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) - await anext(iterator) - await anext(iterator) - await anext(iterator) - - self.resume.assert_not_called() + async with contextlib.aclosing(iterator): + await anext(iterator) + await anext(iterator) + await anext(iterator) + self.resume.assert_not_called() async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" @@ -357,7 +359,6 @@ async def test_cancel_get_iter_before_first_frame(self): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) @@ -373,11 +374,10 @@ async def test_cancel_get_iter_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - # Test put + # Test put. async def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" @@ -385,11 +385,9 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() - # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() - # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() @@ -402,10 +400,9 @@ async def test_put_does_not_pause_reading(self): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.pause.assert_not_called() - # Test termination + # Test termination. 
async def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" @@ -467,7 +464,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self): self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_partially_queued_fragmented_message_after_close(self): - """get raises EOF on a partial fragmented message after close is called.""" + """get raises EOFError on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() @@ -496,41 +493,41 @@ async def test_close_is_idempotent(self): self.assembler.close() self.assembler.close() - # Test (non-)concurrency + # Test (non-)concurrency. async def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" - asyncio.create_task(self.assembler.get()) - await asyncio.sleep(0) + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" - asyncio.create_task(alist(self.assembler.get_iter())) - await asyncio.sleep(0) + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" - asyncio.create_task(self.assembler.get()) - await asyncio.sleep(0) + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" - asyncio.create_task(alist(self.assembler.get_iter())) - await asyncio.sleep(0) + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() - # Test setting limits + # Test setting limits. 
async def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" diff --git a/tests/asyncio/test_router.py b/tests/asyncio/test_router.py index 3dd766c96..b746052c1 100644 --- a/tests/asyncio/test_router.py +++ b/tests/asyncio/test_router.py @@ -8,9 +8,8 @@ from websockets.asyncio.router import * from websockets.exceptions import InvalidStatus -from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, alist, temp_unix_socket_path from .server import EvalShellMixin, get_uri, handler -from .utils import alist try: diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 00dcb3010..fe225067c 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -568,7 +568,7 @@ async def test_connection(self): async with serve(*args, ssl=SERVER_CONTEXT) as server: async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") async def test_timeout_during_tls_handshake(self): """Server times out before receiving TLS handshake request from client.""" @@ -604,7 +604,7 @@ async def test_connection(self): async with unix_serve(handler, path, ssl=SERVER_CONTEXT): async with unix_connect(path, ssl=CLIENT_CONTEXT) as client: await self.assertEval(client, "ws.protocol.state.name", "OPEN") - await self.assertEval(client, SSL_OBJECT + ".version()[:3]", "TLS") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") class ServerUsageErrorsTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py deleted file mode 100644 index a611bfc4b..000000000 --- a/tests/asyncio/utils.py +++ /dev/null @@ -1,5 +0,0 @@ -async def alist(async_iterable): - items = [] - async for item in async_iterable: - items.append(item) - return items diff --git a/tests/requirements.txt b/tests/requirements.txt index f375e6f69..77de5350b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,2 +1,3 @@ -python-socks[asyncio] mitmproxy +python-socks[asyncio] +trio diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 415343911..cc5949c93 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -44,7 +44,7 @@ def test_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") @@ -225,9 +225,11 @@ def test_junk_handshake(self): class JunkHandler(socketserver.BaseRequestHandler): def handle(self): - time.sleep(MS) # wait for the client to send the handshake request + # Wait for the client to send the handshake request. + time.sleep(MS) self.request.send(b"220 smtp.invalid ESMTP Postfix\r\n") - self.request.recv(4096) # wait for the client to close the connection + # Wait for the client to close the connection. 
+ self.request.recv(4096) self.request.close() server = socketserver.TCPServer(("localhost", 0), JunkHandler) @@ -401,7 +403,7 @@ def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" with run_server() as server: with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. + # Use a non-existing domain to ensure we connect via sock. with connect("ws://invalid/", sock=sock) as client: self.assertEqual(client.protocol.state.name, "OPEN") self.assertNumFlows(0) @@ -648,7 +650,7 @@ def test_proxy_ssl_without_https_proxy(self): connect( "ws://localhost/", proxy="http://localhost:8080", - proxy_ssl=True, + proxy_ssl=CLIENT_CONTEXT, ) self.assertEqual( str(raised.exception), diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 07730c48c..9a3789d1b 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,12 +1,12 @@ import contextlib +import itertools import logging import socket -import sys import threading import time import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.exceptions import ( ConcurrencyError, @@ -43,14 +43,20 @@ def tearDown(self): # Test helpers built upon RecordingProtocol and InterceptingConnection. + def wait_for_remote_side(self): + """Wait for the remote side to process messages.""" + # We don't have a way to tell if the remote side is blocked on I/O. + # The sync tests still run faster than the asyncio and trio tests :-) + time.sleep(MS) + def assertFrameSent(self, frame): """Check that a single frame was sent.""" - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) def assertNoFrameSent(self): """Check that no frame was sent.""" - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.contextmanager @@ -58,28 +64,28 @@ def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() # Test __enter__ and __exit__. 
@@ -106,40 +112,45 @@ def test_exit_with_exception(self): def test_iter_text(self): """__iter__ yields text messages.""" iterator = iter(self.connection) - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") + with contextlib.closing(iterator): + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") def test_iter_binary(self): """__iter__ yields binary messages.""" iterator = iter(self.connection) - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + with contextlib.closing(iterator): + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_mixed(self): """__iter__ yields a mix of text and binary messages.""" iterator = iter(self.connection) - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + with contextlib.closing(iterator): + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_connection_closed_ok(self): """__iter__ terminates after a normal closure.""" iterator = iter(self.connection) - self.remote_connection.close() - with self.assertRaises(StopIteration): - next(iterator) + with contextlib.closing(iterator): + self.remote_connection.close() + with self.assertRaises(StopIteration): + next(iterator) def test_iter_connection_closed_error(self): """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) - self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) - with self.assertRaises(ConnectionClosedError): - next(iterator) + with contextlib.closing(iterator): + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + next(iterator) # Test recv. 
@@ -199,16 +210,18 @@ def test_recv_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(ConcurrencyError) as raised: - self.connection.recv() - self.assertEqual( - str(raised.exception), - "cannot call recv while another thread " - "is already running recv or recv_streaming", - ) + try: + with self.assertRaises(ConcurrencyError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ) - self.remote_connection.send("") - recv_thread.join() + finally: + self.remote_connection.send("") + recv_thread.join() def test_recv_during_recv_streaming(self): """recv raises ConcurrencyError when called concurrently with recv_streaming.""" @@ -217,16 +230,18 @@ def test_recv_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaises(ConcurrencyError) as raised: - self.connection.recv() - self.assertEqual( - str(raised.exception), - "cannot call recv while another thread " - "is already running recv or recv_streaming", - ) + try: + with self.assertRaises(ConcurrencyError) as raised: + self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ) - self.remote_connection.send("") - recv_streaming_thread.join() + finally: + self.remote_connection.send("") + recv_streaming_thread.join() # Test recv_streaming. @@ -308,17 +323,19 @@ def test_recv_streaming_during_recv(self): recv_thread = threading.Thread(target=self.connection.recv) recv_thread.start() - with self.assertRaises(ConcurrencyError) as raised: - for _ in self.connection.recv_streaming(): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - "cannot call recv_streaming while another thread " - "is already running recv or recv_streaming", - ) + try: + with self.assertRaises(ConcurrencyError) as raised: + for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming", + ) - self.remote_connection.send("") - recv_thread.join() + finally: + self.remote_connection.send("") + recv_thread.join() def test_recv_streaming_during_recv_streaming(self): """recv_streaming raises ConcurrencyError when called concurrently.""" @@ -327,17 +344,19 @@ def test_recv_streaming_during_recv_streaming(self): ) recv_streaming_thread.start() - with self.assertRaises(ConcurrencyError) as raised: - for _ in self.connection.recv_streaming(): - self.fail("did not raise") - self.assertEqual( - str(raised.exception), - r"cannot call recv_streaming while another thread " - r"is already running recv or recv_streaming", - ) + try: + with self.assertRaises(ConcurrencyError) as raised: + for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another thread " + r"is already running recv or recv_streaming", + ) - self.remote_connection.send("") - recv_streaming_thread.join() + finally: + self.remote_connection.send("") + recv_streaming_thread.join() # Test send. 
@@ -488,9 +507,15 @@ def test_close_explicit_code_reason(self): self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) def test_close_waits_for_close_frame(self): - """close waits for a close frame (then EOF) before returning.""" + """close waits for a close frame then EOF before returning.""" + t0 = time.time() with self.delay_frames_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -504,8 +529,14 @@ def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.delay_eof_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -516,8 +547,14 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = time.time() with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() @@ -531,8 +568,14 @@ def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -547,13 +590,9 @@ def test_close_preserves_queued_messages(self): self.connection.close() self.assertEqual(self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -621,16 +660,18 @@ def closer(): exit_gate.set() def fragments(): - yield "😀" + yield "⏳" close_gate.set() exit_gate.wait() - yield "😀" + yield "⌛️" close_thread = threading.Thread(target=closer) close_thread.start() - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.send(fragments()) + iterator = fragments() + with contextlib.closing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send(iterator) exc = raised.exception self.assertEqual( @@ -644,9 +685,10 @@ def fragments(): # Test ping. 
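Note: the close tests above now record wall-clock time around close() and assert the final state and close code, checking that close() really waited for the delayed close frame or EOF instead of returning early. A minimal sketch of the timing assertion, assuming the tests' millisecond unit MS (defined elsewhere in the test utilities):

import time

MS = 0.001  # assumption: mirrors the MS constant used by these tests

t0 = time.time()
time.sleep(2 * MS)  # stands in for close() blocking on a delayed close frame
t1 = time.time()
assert t1 - t0 > MS  # close() returned only after the simulated network delay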
- @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping() getrandbits.assert_called_once_with(32) self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -664,38 +706,38 @@ def test_ping_explicit_binary(self): def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("this") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("that") - self.assertFalse(pong_waiter.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_on_close(self): """ping with ack_on_close is acknowledged when the connection is closed.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) - pong_waiter = self.connection.ping("that") + pong_received_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_received = self.connection.ping("that") self.connection.close() - self.assertTrue(pong_waiter_ack_on_close.wait(MS)) - self.assertFalse(pong_waiter.wait(MS)) + self.assertTrue(pong_received_ack_on_close.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("idem") + pong_received = self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") @@ -705,7 +747,7 @@ def test_ping_duplicate_payload(self): ) self.remote_connection.pong("idem") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) self.connection.ping("idem") # doesn't raise an exception @@ -738,10 +780,11 @@ def test_pong_unsupported_type(self): # Test keepalive. 
- @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 4 * MS + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) @@ -759,25 +802,27 @@ def test_disable_keepalive(self): self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_thread) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. - # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. - time.sleep(2 * MS) + time.sleep(3 * MS) # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None with self.drop_frames_rcvd(): @@ -787,7 +832,7 @@ def test_keepalive_ignores_timeout(self, getrandbits): # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection remains open. - time.sleep(2 * MS) + time.sleep(3 * MS) # 7 ms: check that the connection is still open. self.assertEqual(self.connection.state, State.OPEN) @@ -795,6 +840,7 @@ def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) time.sleep(MS) self.connection.close() self.connection.keepalive_thread.join(MS) @@ -802,8 +848,9 @@ def test_keepalive_terminates_while_sleeping(self): def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" - self.connection.ping_interval = 1 * MS + self.connection.ping_interval = MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() self.assertFalse(self.connection.keepalive_thread.is_alive()) @@ -826,14 +873,13 @@ def test_keepalive_terminates_while_waiting_for_pong(self): def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - with self.drop_frames_rcvd(): - self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is dropped. - with self.assertLogs("websockets", logging.ERROR) as logs: - with patch("threading.Event.wait", side_effect=Exception("BOOM")): - time.sleep(3 * MS) - # Exiting the context manager sleeps for 1 ms. 
+ self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + time.sleep(3 * MS) self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], @@ -847,11 +893,8 @@ def test_keepalive_reports_errors(self): def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), close_timeout=42 * MS, ) @@ -859,11 +902,8 @@ def test_close_timeout(self): def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=4, ) @@ -871,11 +911,8 @@ def test_max_queue(self): def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=None, ) @@ -884,11 +921,8 @@ def test_max_queue_none(self): def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=(4, 2), ) @@ -943,17 +977,6 @@ def test_close_reason(self): # Test reporting of network errors. - @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") - def test_reading_in_recv_events_fails(self): - """Error when reading incoming frames is correctly reported.""" - # Inject a fault by closing the socket. This works only on BSD. - # I cannot find a way to achieve the same effect on Linux. - self.connection.socket.close() - # The connection closed exception reports the injected fault. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - self.assertIsInstance(raised.exception.__cause__, IOError) - def test_writing_in_recv_events_fails(self): """Error when responding to incoming frames is correctly reported.""" # Inject a fault by shutting down the socket for writing — but not by @@ -979,23 +1002,14 @@ def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - # Inject a fault in a random call in recv_events(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) def test_unexpected_failure_in_recv_events(self, events_received): """Unexpected internal error in recv_events() is correctly reported.""" - # Receive a message to trigger the fault. 
self.remote_connection.send("😀") - with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) - - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" @@ -1003,10 +1017,7 @@ def test_unexpected_failure_in_send_context(self, send_text): # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") - - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) class ServerConnectionTests(ClientConnectionTests): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index e42784094..7a6569ddd 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,3 +1,4 @@ +import contextlib import time import unittest import unittest.mock @@ -16,7 +17,7 @@ def setUp(self): self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) - # Test get + # Test get. def test_get_text_message_already_received(self): """get returns a text message that is already received.""" @@ -40,7 +41,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - self.assertEqual(message, "café") def test_get_binary_message_not_received_yet(self): @@ -53,7 +53,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assertEqual(message, b"tea") def test_get_fragmented_text_message_already_received(self): @@ -84,7 +83,6 @@ def getter(): self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(message, "café") def test_get_fragmented_binary_message_not_received_yet(self): @@ -99,7 +97,6 @@ def getter(): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(message, b"tea") def test_get_fragmented_text_message_being_received(self): @@ -114,7 +111,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(message, "café") def test_get_fragmented_binary_message_being_received(self): @@ -129,7 +125,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(message, b"tea") def test_get_encoded_text_message(self): @@ -153,11 +148,9 @@ def test_get_resumes_reading(self): # queue is above the low-water mark self.assembler.get() self.resume.assert_not_called() - # queue is at the low-water mark self.assembler.get() self.resume.assert_called_once_with() - # queue is below the low-water mark self.assembler.get() self.resume.assert_called_once_with() @@ -172,7 +165,6 @@ def 
test_get_does_not_resume_reading(self): self.assembler.get() self.assembler.get() self.assembler.get() - self.resume.assert_not_called() def test_get_timeout_before_first_frame(self): @@ -181,7 +173,6 @@ def test_get_timeout_before_first_frame(self): self.assembler.get(timeout=MS) self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - message = self.assembler.get() self.assertEqual(message, "café") @@ -194,7 +185,6 @@ def test_get_timeout_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - message = self.assembler.get() self.assertEqual(message, "café") @@ -224,7 +214,7 @@ def test_get_timeout_0_fragmented_message_partially_received(self): with self.assertRaises(TimeoutError): self.assembler.get(timeout=0) - # Test get_iter + # Test get_iter. def test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" @@ -240,30 +230,26 @@ def test_get_iter_binary_message_already_received(self): def test_get_iter_text_message_not_received_yet(self): """get_iter yields a text message when it is received.""" - fragments = [] + fragments = None def getter(): nonlocal fragments - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + fragments = list(self.assembler.get_iter()) with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - self.assertEqual(fragments, ["café"]) def test_get_iter_binary_message_not_received_yet(self): """get_iter yields a binary message when it is received.""" - fragments = [] + fragments = None def getter(): nonlocal fragments - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + fragments = list(self.assembler.get_iter()) with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assertEqual(fragments, [b"tea"]) def test_get_iter_fragmented_text_message_already_received(self): @@ -285,42 +271,46 @@ def test_get_iter_fragmented_binary_message_already_received(self): def test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = self.assembler.get_iter() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(next(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(next(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(next(iterator), "é") + with contextlib.closing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = self.assembler.get_iter() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assertEqual(next(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(next(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(next(iterator), b"a") + with contextlib.closing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + 
self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = self.assembler.get_iter() - self.assertEqual(next(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(next(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(next(iterator), "é") + with contextlib.closing(iterator): + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = self.assembler.get_iter() - self.assertEqual(next(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(next(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(next(iterator), b"a") + with contextlib.closing(iterator): + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -345,18 +335,16 @@ def test_get_iter_resumes_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() - - # queue is above the low-water mark - next(iterator) - self.resume.assert_not_called() - - # queue is at the low-water mark - next(iterator) - self.resume.assert_called_once_with() - - # queue is below the low-water mark - next(iterator) - self.resume.assert_called_once_with() + with contextlib.closing(iterator): + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" @@ -366,13 +354,13 @@ def test_get_iter_does_not_resume_reading(self): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() - next(iterator) - next(iterator) - next(iterator) + with contextlib.closing(iterator): + next(iterator) + next(iterator) + next(iterator) + self.resume.assert_not_called() - self.resume.assert_not_called() - - # Test put + # Test put. 
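Note: these assembler tests exercise the high-water / low-water flow control: put() pauses reading once the queue reaches the high mark, and get() / get_iter() resume it once the queue drains back to the low mark, each callback firing only once per transition. A minimal sketch of that hysteresis with high=2 and low=1, as in the fixture (the class below is illustrative, not the library's Assembler):

from collections import deque

class TinyAssembler:
    def __init__(self, high, low, pause, resume):
        self.frames = deque()
        self.high, self.low = high, low
        self.pause, self.resume = pause, resume
        self.paused = False

    def put(self, frame):
        self.frames.append(frame)
        if not self.paused and len(self.frames) >= self.high:
            self.paused = True
            self.pause()    # reached the high-water mark: stop reading

    def get(self):
        frame = self.frames.popleft()
        if self.paused and len(self.frames) <= self.low:
            self.paused = False
            self.resume()   # drained to the low-water mark: read again
        return frame

q = TinyAssembler(high=2, low=1, pause=lambda: print("pause"), resume=lambda: print("resume"))
q.put("a"); q.put("b")  # prints "pause" when the second frame arrives
q.get(); q.get()        # prints "resume" once the queue falls to the low mark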
def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" @@ -380,11 +368,9 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() - # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() - # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() @@ -397,10 +383,9 @@ def test_put_does_not_pause_reading(self): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.pause.assert_not_called() - # Test termination + # Test termination. def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" @@ -501,7 +486,6 @@ def test_close_resumes_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) - # queue is at the high-water mark assert self.assembler.paused @@ -513,7 +497,7 @@ def test_close_is_idempotent(self): self.assembler.close() self.assembler.close() - # Test (non-)concurrency + # Test (non-)concurrency. def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" @@ -543,7 +527,7 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - # Test setting limits + # Test setting limits. def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" diff --git a/tests/test_localhost.cnf b/tests/test_localhost.cnf index 4069e3967..15d49228c 100644 --- a/tests/test_localhost.cnf +++ b/tests/test_localhost.cnf @@ -24,4 +24,5 @@ subjectAltName = @san DNS.1 = localhost DNS.2 = overridden IP.3 = 127.0.0.1 -IP.4 = ::1 +IP.4 = 0.0.0.0 +IP.5 = ::1 diff --git a/tests/test_localhost.pem b/tests/test_localhost.pem index 8df63ec8f..1f26df715 100644 --- a/tests/test_localhost.pem +++ b/tests/test_localhost.pem @@ -1,48 +1,49 @@ -----BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDYOOQyq8yYtn5x -K3yRborFxTFse16JIVb4x/ZhZgGm49eARCi09fmczQxJdQpHz81Ij6z0xi7AUYH7 -9wS8T0Lh3uGFDDS1GzITUVPIqSUi0xim2T6XPzXFVQYI1D/OjUxlHm+3/up+WwbL -sBgBO/lDmzoa3ZN7kt9HQoGc/14oQz1Qsv1QTDQs69r+o7mmBJr/hf/g7S0Csyy3 -iC6aaq+yCUyzDbjXceTI7WJqbTGNnK0/DjdFD/SJS/uSDNEg0AH53eqcCSjm+Ei/ -UF8qR5Pu4sSsNwToOW2MVgjtHFazc+kG3rzD6+3Dp+t6x6uI/npyuudOMCmOtd6z -kX0UPQaNAgMBAAECggEAS4eMBztGC+5rusKTEAZKSY15l0h9HG/d/qdzJFDKsO6T -/8VPZu8pk6F48kwFHFK1hexSYWq9OAcA3fBK4jDZzybZJm2+F6l5U5AsMUMMqt6M -lPP8Tj8RXG433muuIkvvbL82DVLpvNu1Qv+vUvcNOpWFtY7DDv6eKjlMJ3h4/pzh -89MNt26VMCYOlq1NSjuZBzFohL2u9nsFehlOpcVsqNfNfcYCq9+5yoH8fWJP90Op -hqhvqUoGLN7DRKV1f+AWHSA4nmGgvVviV5PQgMhtk5exlN7kG+rDc3LbzhefS1Sp -Tat1qIgm8fK2n+Q/obQPjHOGOGuvE5cIF7E275ZKgQKBgQDt87BqALKWnbkbQnb7 -GS1h6LRcKyZhFbxnO2qbviBWSo15LEF8jPGV33Dj+T56hqufa/rUkbZiUbIR9yOX -dnOwpAVTo+ObAwZfGfHvrnufiIbHFqJBumaYLqjRZ7AC0QtS3G+kjS9dbllrr7ok -fO4JdfKRXzBJKrkQdCn8hR22rQKBgQDon0b49Dxs1EfdSDbDode2TSwE83fI3vmR -SKUkNY8ma6CRbomVRWijhBM458wJeuhpjPZOvjNMsnDzGwrtdAp2VfFlMIDnA8ZC -fEWIAAH2QYKXKGmkoXOcWB2QbvbI154zCm6zFGtzvRKOCGmTXuhFajO8VPwOyJVt -aSJA3bLrYQKBgQDJM2/tAfAAKRdW9GlUwqI8Ep9G+/l0yANJqtTnIemH7XwYhJJO -9YJlPszfB2aMBgliQNSUHy1/jyKpzDYdITyLlPUoFwEilnkxuud2yiuf5rpH51yF 
-hU6wyWtXvXv3tbkEdH42PmdZcjBMPQeBSN2hxEi6ISncBDL9tau26PwJ9QKBgQCs -cNYl2reoXTzgtpWSNDk6NL769JjJWTFcF6QD0YhKjOI8rNpkw00sWc3+EybXqDr9 -c7dq6+gPZQAB1vwkxi6zRkZqIqiLl+qygnjwtkC+EhYCg7y8g8q2DUPtO7TJcb0e -TQ9+xRZad8B3dZj93A8G1hF//OfU9bB/qL3xo+bsQQKBgC/9YJvgLIWA/UziLcB2 -29Ai0nbPkN5df7z4PifUHHSlbQJHKak8UKbMP+8S064Ul0F7g8UCjZMk2LzSbaNY -XU5+2j0sIOnGUFoSlvcpdowzYrD2LN5PkKBot7AOq/v7HlcOoR8J8RGWAMpCrHsI -a/u/dlZs+/K16RcavQwx8rag +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDKiNs9JHIq5I2c +GjupVn8QJ3oi+lSpEwdUu6aw/q1H9mVzv1dFtp7hT8kuhclNf1tlBBFiB+NWbRZc +uyBRq+mIIWfepcHRHpquxyopesD+CdeC0rogq3vry94FJNmN8257WZiraNl3v9ht +eBqTy0xYDsDtl8iYLfT4xPDfJVOMq0R6SQEljWi6jSbR3b74wiLpXoWjvx7KJahH +hd/p48meuq95tGfxDEb7r/h02RpZF5rq2zRqBOcO4nL5drWYBh1I4+RFp+AbCixX +MqWh1e0vl/wXiKwYTPIgqH2DIXxS3m8dn4O74zO0ktRqPkIXMyKAZQkdUNLngE7v +pNeDcQatAgMBAAECggEACRc/WtZvBt7YYu9IgP0btWBF9hoa0yOwA8P97FpQ8YkI +rpa0bVZrnjz2fkZNdwodLd43YBlKZe1ZbhxD1S1+uuYEY3TvpvWC7A78pPz86IEN +TPu/Jt1AMeo4d5vtLoS7fSYLBwl2H7OI03Y0ROeS8FJXfrKixdp2OmLmVcOAXDDj +Eq0Xs2tSXXPVZ8KKGMidKqvfxcVAhOZvJfHvkMJ+tS/FRAn7Qxc1tn7OTUOg+glr +sHdMwImfzDCbyhP5gZXL/MP35UqnKUBAGdJmfp3BkFxk0yGLhlCOefs1/a9PhVOt +Q83+kjWnuYeP3R4jB7fuWtEu0/gPZT/P1iJF4MIhjwKBgQDqPtT+7G7KMThGjdm6 +bu77VDsW10T5uDU55G3LvXHoFTZUnleSOtWrh2mdR3KVj5PdHDR4VSuA0d65S39n +LYVul82FMgjCWKL4odgssPcLD6SsybdF9xXSXJKtQ96eJjW0o7vMu0/CHrhF6whA +EmCeDcD81Bzvj8DbkSyHpIaolwKBgQDdWBn43eVBt8FStAXx3J49pMyw83AXyqNA +3taHTGjG9BnjgsRgQeYmZG82xpD/Yu6dYyzF+rI4iODkSzF1FN+j64ElDRJbAMvS +yThbAKAb+xegh0EQm43+kYG1sDavWT4pvzh6DCltN82eHwJ5utDuneiAB66DeAqY +ttXmw+fPWwKBgHYEoBWsE4mlUMAjWc5Xc+qGnpq8bNEQISkA0Ny0nv4aKdxqRp6z +K9IXEHwgcjeuNgZR3pG9/4QQuRFMW20lfzOgIfj4o3cfZ0SzbhHeOymEgShZHRCQ +E5t/7pqDNlch0y8my0i0GtQn3BnF98soNyuKrG/1gnqkR7uYIgJZP0sTAoGAGHLt +0353H04zzXXTHkcXN4nnjjgljos0gyraGXHINQmrfmToWhWNXXpEipFeXMdJwhq9 +TFUHsJT1+mGP4fXfShTuW/BYsbKh0POnBO5JwS14C6RE/JeiFJdv82i2caHy6tuT +Wm/Td5vtW2Tjehy3jVPl5ZZzoVP2H646bFYBWfcCgYEAkWJLFzvXsF9SW9Ku6cc0 +7Yhuoolad/AWCXe5Q3+k+icgOQFnMsOkuEPIlRHPgjaOnXMq76VyO4a66vK+ucgr +R3O8/h5QZiuxE3dfqXsDrGr/6W2kmDWWXXK9r5oJQ1J4ndj65ZaGcAuw/77hf5K8 +PnN3beykcf5xxuaPNpq0cbg= -----END PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIDWTCCAkGgAwIBAgIJAOL9UKiOOxupMA0GCSqGSIb3DQEBCwUAMEwxCzAJBgNV -BAYTAkZSMQ4wDAYDVQQHDAVQYXJpczEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3Rp -bjESMBAGA1UEAwwJbG9jYWxob3N0MCAXDTIyMTAxNTE5Mjg0MVoYDzIwNjQxMDE0 -MTkyODQxWjBMMQswCQYDVQQGEwJGUjEOMAwGA1UEBwwFUGFyaXMxGTAXBgNVBAoM -EEF5bWVyaWMgQXVndXN0aW4xEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZI -hvcNAQEBBQADggEPADCCAQoCggEBANg45DKrzJi2fnErfJFuisXFMWx7XokhVvjH -9mFmAabj14BEKLT1+ZzNDEl1CkfPzUiPrPTGLsBRgfv3BLxPQuHe4YUMNLUbMhNR -U8ipJSLTGKbZPpc/NcVVBgjUP86NTGUeb7f+6n5bBsuwGAE7+UObOhrdk3uS30dC -gZz/XihDPVCy/VBMNCzr2v6juaYEmv+F/+DtLQKzLLeILppqr7IJTLMNuNdx5Mjt -YmptMY2crT8ON0UP9IlL+5IM0SDQAfnd6pwJKOb4SL9QXypHk+7ixKw3BOg5bYxW -CO0cVrNz6QbevMPr7cOn63rHq4j+enK6504wKY613rORfRQ9Bo0CAwEAAaM8MDow -OAYDVR0RBDEwL4IJbG9jYWxob3N0ggpvdmVycmlkZGVuhwR/AAABhxAAAAAAAAAA -AAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBPNDGDdl4wsCRlDuyCHBC8o+vW -Vb14thUw9Z6UrlsQRXLONxHOXbNAj1sYQACNwIWuNz36HXu5m8Xw/ID/bOhnIg+b -Y6l/JU/kZQYB7SV1aR3ZdbCK0gjfkE0POBHuKOjUFIOPBCtJ4tIBUX94zlgJrR9v -2rqJC3TIYrR7pVQumHZsI5GZEMpM5NxfreWwxcgltgxmGdm7elcizHfz7k5+szwh -4eZ/rxK9bw1q8BIvVBWelRvUR55mIrCjzfZp5ZObSYQTZlW7PzXBe5Jk+1w31YHM -RSBA2EpPhYlGNqPidi7bg7rnQcsc6+hE0OqzTL/hWxPm9Vbp9dj3HFTik1wa +MIIDiTCCAnGgAwIBAgIURQDnIfsMPAhuq9Uq1dka01Qoc9IwDQYJKoZIhvcNAQEL +BQAwTDELMAkGA1UEBhMCRlIxDjAMBgNVBAcMBVBhcmlzMRkwFwYDVQQKDBBBeW1l +cmljIEF1Z3VzdGluMRIwEAYDVQQDDAlsb2NhbGhvc3QwIBcNMjUwNTMxMjAxMDU1 
+WhgPMjA2NzA1MzEyMDEwNTVaMEwxCzAJBgNVBAYTAkZSMQ4wDAYDVQQHDAVQYXJp +czEZMBcGA1UECgwQQXltZXJpYyBBdWd1c3RpbjESMBAGA1UEAwwJbG9jYWxob3N0 +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAyojbPSRyKuSNnBo7qVZ/ +ECd6IvpUqRMHVLumsP6tR/Zlc79XRbae4U/JLoXJTX9bZQQRYgfjVm0WXLsgUavp +iCFn3qXB0R6arscqKXrA/gnXgtK6IKt768veBSTZjfNue1mYq2jZd7/YbXgak8tM +WA7A7ZfImC30+MTw3yVTjKtEekkBJY1ouo0m0d2++MIi6V6Fo78eyiWoR4Xf6ePJ +nrqvebRn8QxG+6/4dNkaWRea6ts0agTnDuJy+Xa1mAYdSOPkRafgGwosVzKlodXt +L5f8F4isGEzyIKh9gyF8Ut5vHZ+Du+MztJLUaj5CFzMigGUJHVDS54BO76TXg3EG +rQIDAQABo2EwXzA+BgNVHREENzA1gglsb2NhbGhvc3SCCm92ZXJyaWRkZW6HBH8A +AAGHBAAAAACHEAAAAAAAAAAAAAAAAAAAAAEwHQYDVR0OBBYEFB7eswhXVVmG32UR +MGtc2vewZjM0MA0GCSqGSIb3DQEBCwUAA4IBAQBt9KGnnrtn15H9wz4fWHzPTGaO +laJQE5RnqlzyQ3aDLRtZIc/OA+0L6rW7+xiiN0v1irqCD/M0YGYGomy//3J444bT +SxciJQarZPtNRaLJx17geQOwbY5NpTsfEKmvhwCnMLx9Wy6kyHx0NyD3e1MJwH47 +QdJDmKCVF2R10AKGlnsp6zYaoOvoY48MvCBOnaZEVXPypta0N3XXrASsllw5QJSb +XXPIdNbwA22necSoa7PchMXIbyDXIhygf+tXVBAKvNaSNCzQPehTmepENYJPFEh/ +NJrYPB769uRPgZxIvivo1QjNik4ywcZlvEU6LC6JPUasUcGY6FTnipLL6lD0 -----END CERTIFICATE----- diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 000000000..e0d12898e --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,233 @@ +import os +import unittest +from unittest.mock import patch + +from websockets.exceptions import InvalidProxy +from websockets.http11 import USER_AGENT +from websockets.proxy import * +from websockets.proxy import prepare_connect_request +from websockets.uri import parse_uri + + +VALID_PROXIES = [ + ( + "http://proxy:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "https://proxy:8080", + Proxy("https", "proxy", 8080, None, None), + ), + ( + "http://proxy", + Proxy("http", "proxy", 80, None, None), + ), + ( + "http://proxy:8080/", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://PROXY:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://user:pass@proxy:8080", + Proxy("http", "proxy", 8080, "user", "pass"), + ), + ( + "http://høst:8080/", + Proxy("http", "xn--hst-0na", 8080, None, None), + ), + ( + "http://üser:påss@høst:8080", + Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), + ), +] + +INVALID_PROXIES = [ + "ws://proxy:8080", + "wss://proxy:8080", + "http://proxy:8080/path", + "http://proxy:8080/?query", + "http://proxy:8080/#fragment", + "http://user@proxy", + "http:///", +] + +PROXIES_WITH_USER_INFO = [ + ("http://proxy", None), + ("http://user:pass@proxy", ("user", "pass")), + ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), +] + +PROXY_ENVS = [ + ( + {"ws_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"ws_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "ws://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, 
+ "ws://example.com/", + "http://proxy1:8080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, + "ws://example.local/", + None, + ), +] + +CONNECT_REQUESTS = [ + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n\r\n" + ), + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + ( + b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n\r\n" + ), + ), + ( + {"https_proxy": "http://hello:iloveyou@proxy:8080"}, + "ws://example.com/", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n" + b"Proxy-Authorization: Basic aGVsbG86aWxvdmV5b3U=\r\n\r\n" + ), + ), +] + +CONNECT_REQUESTS_WITH_USER_AGENT = [ + ( + "Smith", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: Smith\r\n\r\n" + ), + ), + ( + None, + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n\r\n", + ), +] + + +class ProxyTests(unittest.TestCase): + def test_parse_valid_proxies(self): + for proxy, parsed in VALID_PROXIES: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy), parsed) + + def test_parse_invalid_proxies(self): + for proxy in INVALID_PROXIES: + with self.subTest(proxy=proxy): + with self.assertRaises(InvalidProxy): + parse_proxy(proxy) + + def test_parse_proxy_user_info(self): + for proxy, user_info in PROXIES_WITH_USER_INFO: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy).user_info, user_info) + + def test_get_proxy(self): + for environ, uri, proxy in PROXY_ENVS: + with patch.dict(os.environ, environ): + with self.subTest(environ=environ, uri=uri): + self.assertEqual(get_proxy(parse_uri(uri)), proxy) + + def test_prepare_connect_request(self): + for environ, uri, request in CONNECT_REQUESTS: + with patch.dict(os.environ, environ): + with self.subTest(environ=environ, uri=uri): + uri = parse_uri(uri) + proxy = parse_proxy(get_proxy(uri)) + self.assertEqual(prepare_connect_request(proxy, uri), request) + + def test_prepare_connect_request_with_user_agent(self): + for user_agent_header, request in CONNECT_REQUESTS_WITH_USER_AGENT: + with self.subTest(user_agent_header=user_agent_header): + uri = parse_uri("ws://example.com") + proxy = parse_proxy("http://proxy:8080") + self.assertEqual( + prepare_connect_request(proxy, uri, user_agent_header), + request, + ) diff --git a/tests/test_uri.py b/tests/test_uri.py index 3ccf21158..057a17291 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,10 +1,7 @@ -import os import unittest -from unittest.mock import patch -from websockets.exceptions import InvalidProxy, InvalidURI +from websockets.exceptions import InvalidURI from websockets.uri import * -from 
websockets.uri import Proxy, get_proxy, parse_proxy VALID_URIS = [ @@ -75,145 +72,6 @@ ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), ] -VALID_PROXIES = [ - ( - "http://proxy:8080", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "https://proxy:8080", - Proxy("https", "proxy", 8080, None, None), - ), - ( - "http://proxy", - Proxy("http", "proxy", 80, None, None), - ), - ( - "http://proxy:8080/", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "http://PROXY:8080", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "http://user:pass@proxy:8080", - Proxy("http", "proxy", 8080, "user", "pass"), - ), - ( - "http://høst:8080/", - Proxy("http", "xn--hst-0na", 8080, None, None), - ), - ( - "http://üser:påss@høst:8080", - Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), - ), -] - -INVALID_PROXIES = [ - "ws://proxy:8080", - "wss://proxy:8080", - "http://proxy:8080/path", - "http://proxy:8080/?query", - "http://proxy:8080/#fragment", - "http://user@proxy", - "http:///", -] - -PROXIES_WITH_USER_INFO = [ - ("http://proxy", None), - ("http://user:pass@proxy", ("user", "pass")), - ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), -] - -PROXY_ENVS = [ - ( - {"ws_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"ws_proxy": "http://proxy:8080"}, - "wss://example.com/", - None, - ), - ( - {"wss_proxy": "http://proxy:8080"}, - "ws://example.com/", - None, - ), - ( - {"wss_proxy": "http://proxy:8080"}, - "wss://example.com/", - "http://proxy:8080", - ), - ( - {"http_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"http_proxy": "http://proxy:8080"}, - "wss://example.com/", - None, - ), - ( - {"https_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"https_proxy": "http://proxy:8080"}, - "wss://example.com/", - "http://proxy:8080", - ), - ( - {"socks_proxy": "http://proxy:1080"}, - "ws://example.com/", - "socks5h://proxy:1080", - ), - ( - {"socks_proxy": "http://proxy:1080"}, - "wss://example.com/", - "socks5h://proxy:1080", - ), - ( - {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, - "ws://example.com/", - "http://proxy1:8080", - ), - ( - {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, - "wss://example.com/", - "http://proxy2:8080", - ), - ( - {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, - "ws://example.com/", - "http://proxy2:8080", - ), - ( - {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, - "wss://example.com/", - "http://proxy2:8080", - ), - ( - {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, - "ws://example.com/", - "socks5h://proxy:1080", - ), - ( - {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, - "wss://example.com/", - "socks5h://proxy:1080", - ), - ( - {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, - "ws://example.local/", - None, - ), -] - class URITests(unittest.TestCase): def test_parse_valid_uris(self): @@ -236,25 +94,3 @@ def test_parse_user_info(self): for uri, user_info in URIS_WITH_USER_INFO: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).user_info, user_info) - - def test_parse_valid_proxies(self): - for proxy, parsed in VALID_PROXIES: - with self.subTest(proxy=proxy): - self.assertEqual(parse_proxy(proxy), parsed) - - def test_parse_invalid_proxies(self): - for proxy in INVALID_PROXIES: - with self.subTest(proxy=proxy): - with 
self.assertRaises(InvalidProxy): - parse_proxy(proxy) - - def test_parse_proxy_user_info(self): - for proxy, user_info in PROXIES_WITH_USER_INFO: - with self.subTest(proxy=proxy): - self.assertEqual(parse_proxy(proxy).user_info, user_info) - - def test_get_proxy(self): - for environ, uri, proxy in PROXY_ENVS: - with patch.dict(os.environ, environ): - with self.subTest(environ=environ, uri=uri): - self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/trio/__init__.py b/tests/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/trio/connection.py b/tests/trio/connection.py new file mode 100644 index 000000000..226f74a3a --- /dev/null +++ b/tests/trio/connection.py @@ -0,0 +1,116 @@ +import contextlib + +import trio + +from websockets.trio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stream = InterceptingStream(self.stream) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_all is None + self.stream.delay_send_all = delay + try: + yield + finally: + self.stream.delay_send_all = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_eof is None + self.stream.delay_send_eof = delay + try: + yield + finally: + self.stream.delay_send_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_all + self.stream.drop_send_all = True + try: + yield + finally: + self.stream.drop_send_all = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_eof + self.stream.drop_send_eof = True + try: + yield + finally: + self.stream.drop_send_eof = False + + +class InterceptingStream: + """ + Stream wrapper that intercepts calls to ``send_all()`` and ``send_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + """ + + # We cannot delay EOF with trio's virtual streams because close_hook is + # synchronous. We adopt the same approach as the other implementations. 
+ + def __init__(self, stream): + self.stream = stream + self.delay_send_all = None + self.delay_send_eof = None + self.drop_send_all = False + self.drop_send_eof = False + + def __getattr__(self, name): + return getattr(self.stream, name) + + async def send_all(self, data): + if self.delay_send_all is not None: + await trio.sleep(self.delay_send_all) + if not self.drop_send_all: + await self.stream.send_all(data) + + async def send_eof(self): + if self.delay_send_eof is not None: + await trio.sleep(self.delay_send_eof) + if not self.drop_send_eof: + await self.stream.send_eof() + + +trio.abc.HalfCloseableStream.register(InterceptingStream) diff --git a/tests/trio/server.py b/tests/trio/server.py new file mode 100644 index 000000000..d2172af21 --- /dev/null +++ b/tests/trio/server.py @@ -0,0 +1,63 @@ +import contextlib +import functools +import socket +import urllib.parse + +import trio + +from websockets.trio.server import * + + +def get_host_port(listeners): + for listener in listeners: + if listener.socket.family == socket.AF_INET: # pragma: no branch + return listener.socket.getsockname() + raise AssertionError("expected at least one IPv4 socket") + + +def get_uri(server, secure=False): + protocol = "wss" if secure else "ws" + host, port = get_host_port(server.listeners) + return f"{protocol}://{host}:{port}" + + +async def handler(ws): + path = urllib.parse.urlparse(ws.request.path).path + if path == "/": + # The default path is an eval shell. + async for expr in ws: + value = eval(expr) + await ws.send(str(value)) + elif path == "/crash": + raise RuntimeError + elif path == "/no-op": + pass + elif path == "/delay": + delay = float(await ws.recv()) + await ws.aclose() + await trio.sleep(delay) + else: + raise AssertionError(f"unexpected path: {path}") + + +kwargs = {"handler": handler, "port": 0, "host": "localhost"} + + +@contextlib.asynccontextmanager +async def run_server(**overrides): + merged_kwargs = {**kwargs, **overrides} + async with trio.open_nursery() as nursery: + server = await nursery.start(functools.partial(serve, **merged_kwargs)) + try: + yield server + finally: + # Run all tasks to guarantee that any exceptions are raised. + # Otherwise, canceling the nursery could hide errors. 
+ await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +class EvalShellMixin: + async def assertEval(self, client, expr, value): + await client.send(expr) + self.assertEqual(await client.recv(), value) diff --git a/tests/trio/test_client.py b/tests/trio/test_client.py new file mode 100644 index 000000000..7448b5fd1 --- /dev/null +++ b/tests/trio/test_client.py @@ -0,0 +1,927 @@ +import contextlib +import http +import logging +import os +import socket +import ssl +import sys +import unittest +from unittest.mock import patch + +import trio + +from websockets.client import backoff +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidProxy, + InvalidProxyMessage, + InvalidStatus, + InvalidURI, + ProxyError, + SecurityError, +) +from websockets.extensions.permessage_deflate import PerMessageDeflate +from websockets.trio.client import * + +from ..proxy import ProxyMixin +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT +from .server import get_host_port, get_uri, run_server +from .utils import IsolatedTrioTestCase + + +@contextlib.asynccontextmanager +async def short_backoff_delay(): + defaults = backoff.__defaults__ + backoff.__defaults__ = ( + defaults[0] * MS, + defaults[1] * MS, + defaults[2] * MS, + defaults[3], + ) + try: + yield + finally: + backoff.__defaults__ = defaults + + +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.trio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + +class ClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with run_server() as server: + host, port = get_host_port(server.listeners) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_stream(self): + """Client connects using a pre-existing stream.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + # Use a non-existing domain to ensure we connect via stream. 
+ async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with run_server() as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with run_server() as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_legacy_user_agent(self): + """Client can override User-Agent header with additional_headers.""" + async with run_server() as server: + async with connect( + get_uri(server), additional_headers={"User-Agent": "Smith"} + ) as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await trio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with run_server() as server: + async with connect(get_uri(server), ping_interval=None) as client: + await trio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server() as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with run_server() as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + @short_backoff_delay() + async def test_reconnect(self): + """Client reconnects to server.""" + iterations = 0 + successful = 0 + + async def process_request(connection, request): + nonlocal iterations + iterations += 1 + # Retriable errors + if iterations == 1: + await trio.sleep(3 * MS) + elif iterations == 2: + await connection.stream.aclose() + elif iterations == 3: + return 
connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + # Fatal error + elif iterations == 6: + return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async for client in connect(get_uri(server), open_timeout=3 * MS): + self.assertEqual(client.protocol.state.name, "OPEN") + successful += 1 + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 402", + ) + self.assertEqual(iterations, 6) + self.assertEqual(successful, 2) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception(self): + """Client runs process_exception to tell if errors are retryable or fatal.""" + iteration = 0 + + def process_request(connection, request): + nonlocal iteration + iteration += 1 + if iteration == 1: + return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus): + if 500 <= exc.response.status_code < 600: + return None + if exc.response.status_code == 418: + return Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual(iteration, 2) + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + @short_backoff_delay() + async def test_reconnect_with_custom_process_exception_raising_exception(self): + """Client supports raising an exception in process_exception.""" + + def process_request(connection, request): + return connection.respond(http.HTTPStatus.IM_A_TEAPOT, "🫖") + + def process_exception(exc): + if isinstance(exc, InvalidStatus) and exc.response.status_code == 418: + raise Exception("🫖 💔 ☕️") + self.fail("unexpected exception") + + async with run_server(process_request=process_request) as server: + with self.assertRaises(Exception) as raised: + async for _ in connect( + get_uri(server), process_exception=process_exception + ): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "🫖 💔 ☕️", + ) + + async def test_redirect(self): + """Client follows redirect.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with run_server(process_request=redirect) as server: + async with run_server() as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + @few_redirects() + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + 
return response + + async with run_server(process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) + + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" + + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + async with connect( + "ws://overridden/redirect", host=host, port=port + ) as client: + self.assertEqual(client.protocol.uri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with run_server(process_request=redirect) as server: + host, port = get_host_port(server.listeners) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_stream(self): + """Client doesn't follow redirect when using a pre-existing stream.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with run_server(process_request=redirect) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect via sock. + async with connect("ws://invalid/redirect", stream=stream): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting stream", + ) + + async def test_invalid_uri(self): + """Client receives an invalid URI.""" + with self.assertRaises(InvalidURI): + async with connect("http://localhost"): # invalid scheme + self.fail("did not raise") + + async def test_tcp_connection_fails(self): + """Client fails to connect to server.""" + with self.assertRaises(OSError): + async with connect("ws://localhost:54321"): # invalid port + self.fail("did not raise") + + async def test_handshake_fails(self): + """Client connects to server but the handshake fails.""" + + def remove_accept_header(self, request, response): + del response.headers["Sec-WebSocket-Accept"] + + # The connection will be open for the server but failed for the client. + # Use a connection handler that exits immediately to avoid an exception. 
+ async with run_server(process_response=remove_accept_header) as server: + with self.assertRaises(InvalidHandshake) as raised: + async with connect(get_uri(server) + "/no-op", close_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "missing Sec-WebSocket-Accept header", + ) + + async def test_timeout_during_handshake(self): + """Client times out before receiving handshake response from server.""" + # Replace the WebSocket server with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with self.assertRaises(TimeoutError) as raised: + async with connect(f"ws://{host}:{port}", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + async def test_connection_closed_during_handshake(self): + """Client reads EOF before receiving handshake response from server.""" + + async def close_connection(self, request): + await self.stream.aclose() + + async with run_server(process_request=close_connection) as server: + with self.assertRaises(InvalidMessage) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), + "connection closed while reading HTTP status line", + ) + + async def test_http_response(self): + """Client reads HTTP response.""" + + def http_response(connection, request): + return connection.respond(http.HTTPStatus.OK, "👌") + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_http_response_without_content_length(self): + """Client reads HTTP response without a Content-Length header.""" + + def http_response(connection, request): + response = connection.respond(http.HTTPStatus.OK, "👌") + del response.headers["Content-Length"] + return response + + async with run_server(process_request=http_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual(raised.exception.response.status_code, 200) + self.assertEqual(raised.exception.response.body.decode(), "👌") + + async def test_junk_handshake(self): + """Client closes the connection when receiving non-HTTP response from server.""" + + async def junk(stream): + # Wait for the client to send the handshake request. + await trio.testing.wait_all_tasks_blocked() + await stream.send_all(b"220 smtp.invalid ESMTP Postfix\r\n") + # Wait for the client to close the connection. 
+ await stream.receive_some() + await stream.aclose() + + async with trio.open_nursery() as nursery: + try: + listeners = await nursery.start(trio.serve_tcp, junk, 0) + host, port = get_host_port(listeners) + with self.assertRaises(InvalidMessage) as raised: + async with connect(f"ws://{host}:{port}"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), + "unsupported protocol; expected HTTP/1.1: " + "220 smtp.invalid ESMTP Postfix", + ) + finally: + nursery.cancel_scope.cancel() + + +class SecureClientTests(IsolatedTrioTestCase): + async def test_connection(self): + """Client connects to server securely.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + + async def test_set_server_hostname_implicitly(self): + """Client sets server_hostname to the host in the WebSocket URI.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + host, port = get_host_port(server.listeners) + async with connect( + "wss://overridden/", host=host, port=port, ssl=CLIENT_CONTEXT + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_set_server_hostname_explicitly(self): + """Client sets server_hostname to the value provided in argument.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="overridden", + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + + async def test_reject_invalid_server_certificate(self): + """Client rejects certificate where server certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server, secure=True)): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + + async def test_reject_invalid_server_hostname(self): + """Client rejects certificate where server hostname doesn't match.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # This hostname isn't included in the test certificate. 
+ async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + server_hostname="invalid", + ): + self.fail("did not raise") + self.assertIsInstance( + raised.exception.__cause__, + ssl.SSLCertVerificationError, + ) + self.assertIn( + "certificate verify failed: Hostname mismatch", + str(raised.exception.__cause__), + ) + + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server, secure=True) + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with run_server(ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server, secure=True), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with run_server(ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server, secure=True) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class SocksProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "socks5@51080" + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_secure_socks_proxy(self): + """Client connects to server securely through a SOCKS5 proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) + async def test_authenticated_socks_proxy(self): + """Client connects to server through an authenticated SOCKS5 proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" + from python_socks import ProxyError as SocksProxyError + + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + 
str(raised.exception),
+            "failed to connect to SOCKS proxy",
+        )
+        self.assertIsInstance(raised.exception.__cause__, SocksProxyError)
+        self.assertNumFlows(0)
+
+    @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"})  # bad port
+    async def test_socks_proxy_connection_failure(self):
+        """Client fails to connect to the SOCKS5 proxy."""
+        from python_socks import ProxyConnectionError as SocksProxyConnectionError
+
+        with self.assertRaises(OSError) as raised:
+            async with connect("ws://example.com/"):
+                self.fail("did not raise")
+        # Don't test str(raised.exception) because we don't control it.
+        self.assertIsInstance(raised.exception, SocksProxyConnectionError)
+        self.assertNumFlows(0)
+
+    async def test_socks_proxy_connection_timeout(self):
+        """Client times out while connecting to the SOCKS5 proxy."""
+        # Replace the proxy with a TCP server that doesn't respond.
+        with socket.create_server(("localhost", 0)) as sock:
+            host, port = sock.getsockname()
+            with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}):
+                with self.assertRaises(TimeoutError) as raised:
+                    async with connect("ws://example.com/", open_timeout=2 * MS):
+                        self.fail("did not raise")
+        self.assertEqual(
+            str(raised.exception),
+            "timed out during opening handshake",
+        )
+        self.assertNumFlows(0)
+
+    async def test_explicit_socks_proxy(self):
+        """Client connects to server through a SOCKS5 proxy set explicitly."""
+        async with run_server() as server:
+            async with connect(
+                get_uri(server),
+                # Take this opportunity to test socks5 instead of socks5h.
+                proxy="socks5://localhost:51080",
+            ) as client:
+                self.assertEqual(client.protocol.state.name, "OPEN")
+        self.assertNumFlows(1)
+
+    @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"})
+    async def test_ignore_proxy_with_existing_stream(self):
+        """Client connects using a pre-existing stream."""
+        async with run_server() as server:
+            stream = await trio.open_tcp_stream(*get_host_port(server.listeners))
+            # Use a non-existing domain to ensure we connect via stream.
+ async with connect("ws://invalid/", stream=stream) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) + + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, IsolatedTrioTestCase): + proxy_mode = "regular@58080" + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with run_server() as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with self.assertRaises(OSError): + async with 
connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=2 * MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + async with run_server() as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + async with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.stream._ssl_object + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with self.assertRaises(trio.BrokenResourceError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception.__cause__), + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(trio.BrokenResourceError) as raised: + # The test certificate is self-signed. 
+ async with connect( + get_uri(server, secure=True), proxy_ssl=self.proxy_context + ): + self.fail("did not raise") + self.assertIsInstance(raised.exception.__cause__, ssl.SSLCertVerificationError) + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception.__cause__).replace("-", " "), + ) + self.assertNumFlows(1) + + +class ClientUsageErrorsTests(IsolatedTrioTestCase): + async def test_ssl_without_secure_uri(self): + """Client rejects ssl when URI isn't secure.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", ssl=CLIENT_CONTEXT): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "ssl argument is incompatible with a ws:// URI", + ) + + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with self.assertRaises(ValueError) as raised: + async with connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + + async def test_invalid_subprotocol(self): + """Client rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + async with connect("ws://localhost/", subprotocols="chat"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Client rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + async with connect("ws://localhost/", compression=False): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + async def test_reentrancy(self): + """Client isn't reentrant.""" + async with run_server() as server: + connecter = connect(get_uri(server)) + async with connecter: + with self.assertRaises(RuntimeError) as raised: + async with connecter: + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "connect() isn't reentrant", + ) diff --git a/tests/trio/test_connection.py b/tests/trio/test_connection.py new file mode 100644 index 000000000..f75e609d3 --- /dev/null +++ b/tests/trio/test_connection.py @@ -0,0 +1,1267 @@ +import contextlib +import itertools +import logging +import uuid +from unittest.mock import patch + +import trio.testing + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol, State +from websockets.trio.connection import * + +from ..protocol import RecordingProtocol +from ..utils import MS, alist +from .connection import InterceptingConnection +from .utils import IsolatedTrioTestCase + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. 
+ + +class ClientConnectionTests(IsolatedTrioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + stream, remote_stream = trio.testing.memory_stream_pair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection( + self.nursery, stream, protocol, close_timeout=2 * MS + ) + self.remote_connection = InterceptingConnection( + self.nursery, remote_stream, remote_protocol + ) + + async def asyncTearDown(self): + await self.remote_connection.aclose() + await self.connection.aclose() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_aexit_with_exception(self): + """__aexit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. 
+ + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.aclose() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnectionClosedError after an error.""" + iterator = aiter(self.connection) + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) + await iterator.aclose() + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 
text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_during_recv(self): + """recv raises ConcurrencyError when called concurrently.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + finally: + nursery.cancel_scope.cancel() + + async def test_recv_during_recv_streaming(self): + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + finally: + nursery.cancel_scope.cancel() + + async def test_recv_cancellation_before_receiving(self): + """recv can be canceled before receiving a message.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv can be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + self.nursery.start_soon(self.remote_connection.send, fragments()) + await trio.testing.wait_all_tasks_blocked() + + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + gate.set() + + # Running recv again receives the complete message. + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. 
+ + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + finally: + nursery.cancel_scope.cancel() + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises ConcurrencyError when called concurrently.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, 
self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + finally: + nursery.cancel_scope.cancel() + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be canceled before receiving a message.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + iterator = fragments() + with self.assertRaises(ConnectionClosedError): + await self.remote_connection.send(iterator) + await iterator.aclose() + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + gate.set() + + # Running recv_streaming again fails. + with self.assertRaises(ConcurrencyError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. 
+ self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_during_send(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an AsyncIterable. 
+ gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + self.nursery.start_soon(self.connection.send, fragments()) + await trio.testing.wait_all_tasks_blocked() + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + self.nursery.start_soon(self.connection.send, "✅") + await trio.testing.wait_all_tasks_blocked() + await self.assertNoFrameSent() + + gate.set() + await trio.testing.wait_all_tasks_blocked() + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.aclose() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.aclose() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. 
+ + async def test_aclose(self): + """aclose sends a close frame.""" + await self.connection.aclose() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_aclose_explicit_code_reason(self): + """aclose sends a close frame with a given code and reason.""" + await self.connection.aclose(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_aclose_waits_for_close_frame(self): + """aclose waits for a close frame then EOF before returning.""" + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_waits_for_connection_closed(self): + """aclose waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_no_timeout_waits_for_close_frame(self): + """aclose without timeout waits for a close frame then EOF before returning.""" + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_no_timeout_waits_for_connection_closed(self): + """aclose without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """aclose times out if no close frame is received.""" + t0 = trio.current_time() + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + 
self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """aclose times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.drop_eof_rcvd(): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_aclose_preserves_queued_messages(self): + """aclose preserves messages buffered in the assembler.""" + await self.remote_connection.send("😀") + await self.connection.aclose() + + self.assertEqual(await self.connection.recv(), "😀") + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_aclose_idempotency(self): + """aclose does nothing if the connection is already closed.""" + await self.connection.aclose() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.aclose() + await self.assertNoFrameSent() + + async def test_aclose_during_recv(self): + """aclose aborts recv when called concurrently with recv.""" + + async def closer(): + await trio.sleep(MS) + await self.connection.aclose() + + self.nursery.start_soon(closer) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_during_send(self): + """aclose fails the connection when called concurrently with send.""" + close_gate = trio.Event() + exit_gate = trio.Event() + + async def closer(): + await close_gate.wait() + await trio.testing.wait_all_tasks_blocked() + await self.connection.aclose() + exit_gate.set() + + async def fragments(): + yield "⏳" + close_gate.set() + await exit_gate.wait() + yield "⌛️" + + self.nursery.start_soon(closer) + + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + closed = trio.Event() + + async def closer(): + await self.connection.wait_closed() + closed.set() + + self.nursery.start_soon(closer) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(closed.is_set()) + + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(closed.is_set()) + + # Test ping. 
+ + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("this") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong for a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received_ack_on_close = await self.connection.ping( + "this", ack_on_close=True + ) + pong_received = await self.connection.ping("that") + await self.connection.aclose() + with trio.fail_after(MS): + await pong_received_ack_on_close.wait() + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("idem") + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + with trio.fail_after(MS): + await pong_received.wait() + + await self.connection.ping("idem") # doesn't raise an exception + + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + + # Test pong. 
+ + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + + # Test keepalive. + + def keepalive_task_is_running(self): + return any( + task.name == "websockets.trio.connection.Connection.keepalive" + for task in self.nursery.child_tasks + ) + + @patch("random.getrandbits") + async def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + self.assertEqual(self.connection.latency, 0) + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await trio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + self.assertFalse(self.keepalive_task_is_running()) + + @patch("random.getrandbits") + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection is closed. + await trio.sleep(3 * MS) + # 8 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits") + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection remains open. + await trio.sleep(3 * MS) + # 7 ms: check that the connection is still open. 
+ self.assertEqual(self.connection.state, State.OPEN) + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + await trio.testing.wait_all_tasks_blocked() + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 4 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await trio.sleep(2 * MS) + # 2 ms: close the connection before ping_timeout elapses. + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("trio.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + await trio.sleep(3 * MS) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + # Test parameters. 
+ + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + async def test_max_queue(self): + """max_queue configures high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + + async def test_max_queue_tuple(self): + """max_queue configures high-water and low-water marks of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @contextlib.asynccontextmanager + async def get_server_and_client_streams(self): + listeners = await trio.open_tcp_listeners(0, host="127.0.0.1") + assert len(listeners) == 1 + listener = listeners[0] + client_stream = await trio.testing.open_stream_to_socket_listener(listener) + client_port = client_stream.socket.getsockname()[1] + server_stream = await listener.accept() + server_port = listener.socket.getsockname()[1] + try: + yield client_stream, server_stream, client_port, server_port + finally: + await server_stream.aclose() + await client_stream.aclose() + await listener.aclose() + + async def test_local_address(self): + """Connection provides a local_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + port = {CLIENT: client_port, SERVER: server_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.local_address, ("127.0.0.1", port)) + + async def test_remote_address(self): + """Connection provides a remote_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + remote_port = {CLIENT: server_port, SERVER: client_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.remote_address, ("127.0.0.1", remote_port)) + + async def test_state(self): + """Connection has a state attribute.""" + 
self.assertIs(self.connection.state, State.OPEN) + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + + # Test reporting of network errors. + + async def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) + async def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + await self.remote_connection.send("😀") + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. 
+        with self.assertRaises(ConnectionClosedError) as raised:
+            await self.connection.send("😀")
+        self.assertIsInstance(raised.exception.__cause__, AssertionError)
+
+
+class ServerConnectionTests(ClientConnectionTests):
+    LOCAL = SERVER
+    REMOTE = CLIENT
diff --git a/tests/trio/test_messages.py b/tests/trio/test_messages.py
new file mode 100644
index 000000000..ea25a8f87
--- /dev/null
+++ b/tests/trio/test_messages.py
@@ -0,0 +1,579 @@
+import contextlib
+import unittest
+import unittest.mock
+
+import trio.testing
+
+from websockets.asyncio.compatibility import aiter, anext
+from websockets.exceptions import ConcurrencyError
+from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
+from websockets.trio.messages import *
+
+from ..utils import alist
+from .utils import IsolatedTrioTestCase
+
+
+class AssemblerTests(IsolatedTrioTestCase):
+    def setUp(self):
+        self.pause = unittest.mock.Mock()
+        self.resume = unittest.mock.Mock()
+        self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume)
+
+    # Test get.
+
+    async def test_get_text_message_already_received(self):
+        """get returns a text message that is already received."""
+        self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
+        message = await self.assembler.get()
+        self.assertEqual(message, "café")
+
+    async def test_get_binary_message_already_received(self):
+        """get returns a binary message that is already received."""
+        self.assembler.put(Frame(OP_BINARY, b"tea"))
+        message = await self.assembler.get()
+        self.assertEqual(message, b"tea")
+
+    async def test_get_text_message_not_received_yet(self):
+        """get returns a text message when it is received."""
+        message = None
+
+        async def getter():
+            nonlocal message
+            message = await self.assembler.get()
+
+        self.nursery.start_soon(getter)
+        await trio.testing.wait_all_tasks_blocked()
+        self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
+        await trio.testing.wait_all_tasks_blocked()
+        self.assertEqual(message, "café")
+
+    async def test_get_binary_message_not_received_yet(self):
+        """get returns a binary message when it is received."""
+        message = None
+
+        async def getter():
+            nonlocal message
+            message = await self.assembler.get()
+
+        self.nursery.start_soon(getter)
+        await trio.testing.wait_all_tasks_blocked()
+        self.assembler.put(Frame(OP_BINARY, b"tea"))
+        await trio.testing.wait_all_tasks_blocked()
+        self.assertEqual(message, b"tea")
+
+    async def test_get_fragmented_text_message_already_received(self):
+        """get reassembles a fragmented text message that is already received."""
+        self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"\xa9"))
+        message = await self.assembler.get()
+        self.assertEqual(message, "café")
+
+    async def test_get_fragmented_binary_message_already_received(self):
+        """get reassembles a fragmented binary message that is already received."""
+        self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"e", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"a"))
+        message = await self.assembler.get()
+        self.assertEqual(message, b"tea")
+
+    async def test_get_fragmented_text_message_not_received_yet(self):
+        """get reassembles a fragmented text message when it is received."""
+        message = None
+
+        async def getter():
+            nonlocal message
+            message = await self.assembler.get()
+
+        self.nursery.start_soon(getter)
+        await trio.testing.wait_all_tasks_blocked()
+        self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + self.resume.assert_not_called() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + async with 
trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter. + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = None + + async def getter(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = None + + async def getter(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await 
anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) 
+ async with contextlib.aclosing(iterator): + await anext(iterator) + await anext(iterator) + await anext(iterator) + self.resume.assert_not_called() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + + async def getter(): + await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get_iter cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async def getter(): + await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + with self.assertRaises(ConcurrencyError): + await alist(self.assembler.get_iter()) + + # Test put. + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_not_called() + + # Test termination. 
+
+    async def test_get_fails_when_interrupted_by_close(self):
+        """get raises EOFError when close is called."""
+
+        async def closer():
+            self.assembler.close()
+
+        self.nursery.start_soon(closer)
+        with self.assertRaises(EOFError):
+            await self.assembler.get()
+
+    async def test_get_iter_fails_when_interrupted_by_close(self):
+        """get_iter raises EOFError when close is called."""
+
+        async def closer():
+            self.assembler.close()
+
+        self.nursery.start_soon(closer)
+        with self.assertRaises(EOFError):
+            async for _ in self.assembler.get_iter():
+                self.fail("no fragment expected")
+
+    async def test_get_fails_after_close(self):
+        """get raises EOFError after close is called."""
+        self.assembler.close()
+        with self.assertRaises(EOFError):
+            await self.assembler.get()
+
+    async def test_get_iter_fails_after_close(self):
+        """get_iter raises EOFError after close is called."""
+        self.assembler.close()
+        with self.assertRaises(EOFError):
+            async for _ in self.assembler.get_iter():
+                self.fail("no fragment expected")
+
+    async def test_get_queued_message_after_close(self):
+        """get returns a message after close is called."""
+        self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
+        self.assembler.close()
+        message = await self.assembler.get()
+        self.assertEqual(message, "café")
+
+    async def test_get_iter_queued_message_after_close(self):
+        """get_iter yields a message after close is called."""
+        self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
+        self.assembler.close()
+        fragments = await alist(self.assembler.get_iter())
+        self.assertEqual(fragments, ["café"])
+
+    async def test_get_queued_fragmented_message_after_close(self):
+        """get reassembles a fragmented message after close is called."""
+        self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"e", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"a"))
+        self.assembler.close()
+        message = await self.assembler.get()
+        self.assertEqual(message, b"tea")
+
+    async def test_get_iter_queued_fragmented_message_after_close(self):
+        """get_iter yields a fragmented message after close is called."""
+        self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"e", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"a"))
+        self.assembler.close()
+        fragments = await alist(self.assembler.get_iter())
+        self.assertEqual(fragments, [b"t", b"e", b"a"])
+
+    async def test_get_partially_queued_fragmented_message_after_close(self):
+        """get raises EOFError on a partial fragmented message after close is called."""
+        self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"e", fin=False))
+        self.assembler.close()
+        with self.assertRaises(EOFError):
+            await self.assembler.get()
+
+    async def test_get_iter_partially_queued_fragmented_message_after_close(self):
+        """get_iter yields a partial fragmented message after close is called."""
+        self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
+        self.assembler.put(Frame(OP_CONT, b"e", fin=False))
+        self.assembler.close()
+        fragments = []
+        with self.assertRaises(EOFError):
+            async for fragment in self.assembler.get_iter():
+                fragments.append(fragment)
+        self.assertEqual(fragments, [b"t", b"e"])
+
+    async def test_put_fails_after_close(self):
+        """put raises EOFError after close is called."""
+        self.assembler.close()
+        with self.assertRaises(EOFError):
+            self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
+
+    async def test_close_is_idempotent(self):
+        """close can be called multiple times safely."""
+ self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency. + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + nursery.start_soon(self.assembler.get) + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + nursery.start_soon(self.assembler.get) + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + nursery.start_soon(self.assembler.get) + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + + # Test setting limits. + + async def test_set_high_water_mark(self): + """high sets the high-water and low-water marks.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) + + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) diff --git a/tests/trio/test_server.py b/tests/trio/test_server.py new file mode 100644 index 000000000..12dcafc7d --- /dev/null +++ b/tests/trio/test_server.py @@ -0,0 +1,831 @@ +import dataclasses +import hmac +import http +import logging + +import trio + +from websockets.exceptions import ( + ConnectionClosedError, + ConnectionClosedOK, + InvalidStatus, + NegotiationError, +) +from websockets.http11 import Request, Response +from websockets.trio.client import connect +from websockets.trio.server import * + +from ..utils import ( + CLIENT_CONTEXT, + MS, + SERVER_CONTEXT, +) +from .server import ( + EvalShellMixin, + get_host_port, + get_uri, + handler, + run_server, +) +from .utils import IsolatedTrioTestCase + + +class ServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives connection from client and the handshake succeeds.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, 
"ws.protocol.state.name", "OPEN") + + async def test_connection_handler_returns(self): + """Connection handler returns.""" + async with run_server() as server: + async with connect(get_uri(server) + "/no-op") as client: + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1000 (OK); then sent 1000 (OK)", + ) + + async def test_connection_handler_raises_exception(self): + """Connection handler raises an exception.""" + async with run_server() as server: + async with connect(get_uri(server) + "/crash") as client: + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1011 (internal error); then sent 1011 (internal error)", + ) + + async def test_existing_listeners(self): + """Server receives connection using pre-existing listeners.""" + listeners = await trio.open_tcp_listeners(0, host="localhost") + host, port = get_host_port(listeners) + async with run_server(port=None, host=None, listeners=listeners): + async with connect(f"ws://{host}:{port}/") as client: # type: ignore + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + + async def test_select_subprotocol(self): + """Server selects a subprotocol with the select_subprotocol callable.""" + + def select_subprotocol(ws, subprotocols): + ws.select_subprotocol_ran = True + assert "chat" in subprotocols + return "chat" + + async with run_server( + subprotocols=["chat"], select_subprotocol=select_subprotocol + ) as server: + async with connect(get_uri(server), subprotocols=["chat"]) as client: + await self.assertEval(client, "ws.select_subprotocol_ran", "True") + await self.assertEval(client, "ws.subprotocol", "chat") + + async def test_select_subprotocol_rejects_handshake(self): + """Server rejects handshake if select_subprotocol raises NegotiationError.""" + + def select_subprotocol(ws, subprotocols): + raise NegotiationError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_select_subprotocol_raises_exception(self): + """Server returns an error if select_subprotocol raises an exception.""" + + def select_subprotocol(ws, subprotocols): + raise RuntimeError + + async with run_server(select_subprotocol=select_subprotocol) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + + async def test_compression_is_enabled(self): + """Server enables compression by default.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, + "[type(ext).__name__ for ext in ws.protocol.extensions]", + "['PerMessageDeflate']", + ) + + async def test_disable_compression(self): + """Server disables compression.""" + async with run_server(compression=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.protocol.extensions", "[]") + + async def test_process_request_returns_none(self): + """Server runs process_request and continues the handshake.""" + + def process_request(ws, request): + self.assertIsInstance(request, Request) + 
ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_async_process_request_returns_none(self): + """Server runs async process_request and continues the handshake.""" + + async def process_request(ws, request): + self.assertIsInstance(request, Request) + ws.process_request_ran = True + + async with run_server(process_request=process_request) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_request_ran", "True") + + async def test_process_request_returns_response(self): + """Server aborts handshake if process_request returns a response.""" + + def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_async_process_request_returns_response(self): + """Server aborts handshake if async process_request returns a response.""" + + async def process_request(ws, request): + return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden") + + async def handler(ws): + self.fail("handler must not run") + + with self.assertNoLogs("websockets", logging.ERROR): + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 403", + ) + + async def test_process_request_raises_exception(self): + """Server returns an error if process_request raises an exception.""" + + def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_async_process_request_raises_exception(self): + """Server returns an error if async process_request raises an exception.""" + + async def process_request(ws, request): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_request=process_request) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_process_response_returns_none(self): + """Server runs process_response but 
keeps the handshake response.""" + + def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_async_process_response_returns_none(self): + """Server runs async process_response but keeps the handshake response.""" + + async def process_response(ws, request, response): + self.assertIsInstance(request, Request) + self.assertIsInstance(response, Response) + ws.process_response_ran = True + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.process_response_ran", "True") + + async def test_process_response_modifies_response(self): + """Server runs process_response and modifies the handshake response.""" + + def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_modifies_response(self): + """Server runs async process_response and modifies the handshake response.""" + + async def process_response(ws, request, response): + response.headers["X-ProcessResponse"] = "OK" + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_replaces_response(self): + """Server runs process_response and replaces the handshake response.""" + + def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_async_process_response_replaces_response(self): + """Server runs async process_response and replaces the handshake response.""" + + async def process_response(ws, request, response): + headers = response.headers.copy() + headers["X-ProcessResponse"] = "OK" + return dataclasses.replace(response, headers=headers) + + async with run_server(process_response=process_response) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.response.headers["X-ProcessResponse"], "OK") + + async def test_process_response_raises_exception(self): + """Server returns an error if process_response raises an exception.""" + + def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + 
async def test_async_process_response_raises_exception(self): + """Server returns an error if async process_response raises an exception.""" + + async def process_response(ws, request, response): + raise RuntimeError("BOOM") + + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server(process_response=process_response) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 500", + ) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + async def test_override_server(self): + """Server can override Server header with server_header.""" + async with run_server(server_header="Neo") as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.response.headers['Server']", "Neo") + + async def test_remove_server(self): + """Server can remove Server header with server_header.""" + async with run_server(server_header=None) as server: + async with connect(get_uri(server)) as client: + await self.assertEval( + client, "'Server' in ws.response.headers", "False" + ) + + async def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + async with run_server(ping_interval=MS) as server: + async with connect(get_uri(server)) as client: + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertGreater(latency, 0) + + async def test_disable_keepalive(self): + """Server disables keepalive.""" + async with run_server(ping_interval=None) as server: + async with connect(get_uri(server)) as client: + await trio.sleep(2 * MS) + await client.send("ws.latency") + latency = eval(await client.recv()) + self.assertEqual(latency, 0) + + async def test_logger(self): + """Server accepts a logger argument.""" + logger = logging.getLogger("test") + async with run_server(logger=logger) as server: + self.assertEqual(server.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Server runs ServerConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + server = ServerConnection(*args, **kwargs) + server.create_connection_ran = True + return server + + async with run_server(create_connection=create_connection) as server: + async with connect(get_uri(server)) as client: + await self.assertEval(client, "ws.create_connection_ran", "True") + + async def test_connections(self): + """Server provides a connections property.""" + async with run_server() as server: + self.assertEqual(server.connections, set()) + async with connect(get_uri(server)) as client: + self.assertEqual(len(server.connections), 1) + ws_id = str(next(iter(server.connections)).id) + await self.assertEval(client, "ws.id", ws_id) + self.assertEqual(server.connections, set()) + + async def test_handshake_fails(self): + """Server receives connection from client but the handshake fails.""" + + def remove_key_header(self, request): + del request.headers["Sec-WebSocket-Key"] + + async with run_server(process_request=remove_key_header) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + 
self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 400", + ) + + async def test_timeout_during_handshake(self): + """Server times out before receiving handshake request from client.""" + async with run_server(open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_handshake(self): + """Server reads EOF before receiving handshake request from client.""" + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + async def test_junk_handshake(self): + """Server closes the connection when receiving non-HTTP request from client.""" + with self.assertLogs("websockets", logging.ERROR) as logs: + async with run_server() as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.send_all(b"HELO relay.invalid\r\n") + try: + # Wait for the server to close the connection. + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + self.assertEqual( + [record.getMessage() for record in logs.records], + ["opening handshake failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], + ["invalid HTTP request line: HELO relay.invalid"], + ) + + async def test_close_server_rejects_connecting_connections(self): + """Server rejects connecting connections with HTTP 503 when closing.""" + + async def process_request(ws, _request): + while not ws.server.closing: + await trio.sleep(0) # pragma: no cover + + async with run_server(process_request=process_request) as server: + + async def close_server(server): + await trio.sleep(MS) + await server.aclose() + + self.nursery.start_soon(close_server, server) + + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 503", + ) + + async def test_close_server_closes_open_connections(self): + """Server closes open connections with close code 1001 when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose() + with self.assertRaises(ConnectionClosedOK) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1001 (going away); then sent 1001 (going away)", + ) + + async def test_close_server_closes_open_connections_with_code_and_reason(self): + """Server closes open connections with custom code and reason when closing.""" + async with run_server() as server: + async with connect(get_uri(server)) as client: + await server.aclose(code=1012, reason="restarting") + with self.assertRaises(ConnectionClosedError) as raised: + await client.recv() + self.assertEqual( + str(raised.exception), + "received 1012 (service restart) restarting; " + "then sent 1012 (service restart) restarting", + ) + + async def test_close_server_keeps_connections_open(self): + """Server waits for client to close open connections when closing.""" + + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + 
server_closed.set() + + async with connect(get_uri(server)) as client: + self.nursery.start_soon(close_server) + + # Server cannot receive new connections. + with self.assertRaises(OSError): + async with connect(get_uri(server)): + self.fail("did not raise") + + # The server waits for the client to close the connection. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await server_closed.wait() + + # Once the client closes the connection, the server terminates. + await client.aclose() + with trio.fail_after(MS): + await server_closed.wait() + + async def test_close_server_keeps_handlers_running(self): + """Server waits for connection handlers to terminate.""" + async with run_server() as server: + server_closed = trio.Event() + + async def close_server(): + await server.aclose(close_connections=False) + server_closed.set() + + async with connect(get_uri(server) + "/delay") as client: + # Delay termination of connection handler. + await client.send(str(3 * MS)) + + self.nursery.start_soon(close_server) + + # The server waits for the connection handler to terminate. + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(2 * MS): + await server_closed.wait() + + # Set a large timeout here, else the test becomes flaky. + with trio.fail_after(5 * MS): + await server_closed.wait() + + +SSL_OBJECT = "ws.stream._ssl_object" + + +class SecureServerTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_connection(self): + """Server receives secure connection from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server, secure=True), ssl=CLIENT_CONTEXT + ) as client: + await self.assertEval(client, "ws.protocol.state.name", "OPEN") + await self.assertEval(client, f"{SSL_OBJECT}.version()[:3]", "TLS") + + async def test_timeout_during_tls_handshake(self): + """Server times out before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT, open_timeout=MS) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + try: + self.assertEqual(await stream.receive_some(4096), b"") + finally: + await stream.aclose() + + async def test_connection_closed_during_tls_handshake(self): + """Server reads EOF before receiving TLS handshake request from client.""" + async with run_server(ssl=SERVER_CONTEXT) as server: + stream = await trio.open_tcp_stream(*get_host_port(server.listeners)) + await stream.aclose() + + +class ServerUsageErrorsTests(IsolatedTrioTestCase): + async def test_missing_port(self): + """Server requires port.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, None) + self.assertEqual( + str(raised.exception), + "port is required when listeners is not provided", + ) + + async def test_port_and_listeners(self): + """Server rejects port when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, port=0, listeners=listeners) + self.assertEqual( + str(raised.exception), + "port is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_host_and_listeners(self): + """Server rejects host when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, host="localhost", listeners=listeners) + self.assertEqual( + str(raised.exception), + "host is incompatible with listeners", + ) + 
finally: + for listener in listeners: + await listener.aclose() + + async def test_backlog_and_listeners(self): + """Server rejects backlog when listeners is provided.""" + listeners = await trio.open_tcp_listeners(0) + try: + with self.assertRaises(ValueError) as raised: + await serve(handler, backlog=65535, listeners=listeners) + self.assertEqual( + str(raised.exception), + "backlog is incompatible with listeners", + ) + finally: + for listener in listeners: + await listener.aclose() + + async def test_invalid_subprotocol(self): + """Server rejects single value of subprotocols.""" + with self.assertRaises(TypeError) as raised: + await serve(handler, subprotocols="chat") + self.assertEqual( + str(raised.exception), + "subprotocols must be a list, not a str", + ) + + async def test_unsupported_compression(self): + """Server rejects incorrect value of compression.""" + with self.assertRaises(ValueError) as raised: + await serve(handler, compression=False) + self.assertEqual( + str(raised.exception), + "unsupported compression: False", + ) + + +class BasicAuthTests(EvalShellMixin, IsolatedTrioTestCase): + async def test_valid_authorization(self): + """basic_auth authenticates client with HTTP Basic Authentication.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_missing_authorization(self): + """basic_auth rejects client without credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_unsupported_authorization(self): + """basic_auth rejects client with unsupported credentials.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Negotiate ..."}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_unknown_username(self): + """basic_auth rejects client with unknown username.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "iloveyou")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_authorization_with_incorrect_password(self): + """basic_auth rejects client with incorrect password.""" + async with run_server( + process_request=basic_auth(credentials=("hello", "changeme")), + ) as server: + with self.assertRaises(InvalidStatus) as raised: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "server rejected WebSocket connection: HTTP 401", + ) + + async def test_list_of_credentials(self): + 
"""basic_auth accepts a list of hard coded credentials.""" + async with run_server( + process_request=basic_auth( + credentials=[ + ("hello", "iloveyou"), + ("bye", "youloveme"), + ] + ), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic YnllOnlvdWxvdmVtZQ=="}, + ) as client: + await self.assertEval(client, "ws.username", "bye") + + async def test_check_credentials_function(self): + """basic_auth accepts a check_credentials function.""" + + def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_check_credentials_coroutine(self): + """basic_auth accepts a check_credentials coroutine.""" + + async def check_credentials(username, password): + return hmac.compare_digest(password, "iloveyou") + + async with run_server( + process_request=basic_auth(check_credentials=check_credentials), + ) as server: + async with connect( + get_uri(server), + additional_headers={"Authorization": "Basic aGVsbG86aWxvdmV5b3U="}, + ) as client: + await self.assertEval(client, "ws.username", "hello") + + async def test_without_credentials_or_check_credentials(self): + """basic_auth requires either credentials or check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth() + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_with_credentials_and_check_credentials(self): + """basic_auth requires only one of credentials and check_credentials.""" + with self.assertRaises(ValueError) as raised: + basic_auth( + credentials=("hello", "iloveyou"), + check_credentials=lambda: False, # pragma: no cover + ) + self.assertEqual( + str(raised.exception), + "provide either credentials or check_credentials", + ) + + async def test_bad_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=42) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: 42", + ) + + async def test_bad_list_of_credentials(self): + """basic_auth receives an unsupported credentials argument.""" + with self.assertRaises(TypeError) as raised: + basic_auth(credentials=[42]) + self.assertEqual( + str(raised.exception), + "invalid credentials argument: [42]", + ) diff --git a/tests/trio/test_utils.py b/tests/trio/test_utils.py new file mode 100644 index 000000000..1ecdd80f1 --- /dev/null +++ b/tests/trio/test_utils.py @@ -0,0 +1,40 @@ +import trio.testing + +from websockets.trio.utils import * + +from .utils import IsolatedTrioTestCase + + +class UtilsTests(IsolatedTrioTestCase): + async def test_race_events(self): + event1 = trio.Event() + event2 = trio.Event() + done = trio.Event() + + async def waiter(): + await race_events(event1, event2) + done.set() + + async with trio.open_nursery() as nursery: + nursery.start_soon(waiter) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(done.is_set()) + + event1.set() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(done.is_set()) + + async def test_race_events_cancelled(self): + event1 = trio.Event() + event2 = trio.Event() + + async def waiter(): + with trio.move_on_after(0): + await 
race_events(event1, event2)
+
+        async with trio.open_nursery() as nursery:
+            nursery.start_soon(waiter)
+
+    async def test_race_events_no_events(self):
+        with self.assertRaises(ValueError):
+            await race_events()
diff --git a/tests/trio/utils.py b/tests/trio/utils.py
new file mode 100644
index 000000000..bf325cc36
--- /dev/null
+++ b/tests/trio/utils.py
@@ -0,0 +1,61 @@
+import functools
+import inspect
+import sys
+import unittest
+
+import trio.testing
+
+
+if sys.version_info[:2] < (3, 11):  # pragma: no cover
+    from exceptiongroup import BaseExceptionGroup
+
+
+class IsolatedTrioTestCase(unittest.TestCase):
+    """
+    Wrap test coroutines with :func:`trio.testing.trio_test` automatically.
+
+    Create a nursery for each test, available in the :attr:`nursery` attribute.
+
+    :meth:`asyncSetUp` and :meth:`asyncTearDown` are supported, similar to
+    :class:`unittest.IsolatedAsyncioTestCase`, but ``addAsyncCleanup`` isn't.
+
+    """
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        for name in unittest.defaultTestLoader.getTestCaseNames(cls):
+            test = getattr(cls, name)
+            if getattr(test, "converted_to_trio", False):
+                return
+            assert inspect.iscoroutinefunction(test)
+            setattr(cls, name, cls.convert_to_trio(test))
+
+    @staticmethod
+    def convert_to_trio(test):
+        @trio.testing.trio_test
+        @functools.wraps(test)
+        async def new_test(self, *args, **kwargs):
+            try:
+                # Provide a nursery so it's easy to start tasks.
+                async with trio.open_nursery() as self.nursery:
+                    await self.asyncSetUp()
+                    try:
+                        return await test(self, *args, **kwargs)
+                    finally:
+                        await self.asyncTearDown()
+            except BaseExceptionGroup as exc:
+                # Unwrap exceptions like unittest.SkipTest. Multiple exceptions
+                # could occur if a test fails with multiple errors; this is OK.
+                try:
+                    trio._util.raise_single_exception_from_group(exc)
+                except trio._util.MultipleExceptionError:  # pragma: no cover
+                    raise
+
+        new_test.converted_to_trio = True
+        return new_test
+
+    async def asyncSetUp(self):
+        pass
+
+    async def asyncTearDown(self):
+        pass
diff --git a/tests/utils.py b/tests/utils.py
index bd3bb0ed9..4db014337 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -112,6 +112,13 @@ def assertDeprecationWarning(self, message):
         self.assertEqual(str(warning.message), message)
 
+
+async def alist(async_iterable):
+    items = []
+    async for item in async_iterable:
+        items.append(item)
+    return items
+
 
 @contextlib.contextmanager
 def temp_unix_socket_path():
     with tempfile.TemporaryDirectory() as temp_dir:
diff --git a/tox.ini b/tox.ini
index ce4572e59..dce6698c3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -17,6 +17,7 @@ pass_env =
 deps =
     py311,py312,py313,py314,coverage,maxi_cov: mitmproxy
    py311,py312,py313,py314,coverage,maxi_cov: python-socks[asyncio]
+    trio
     werkzeug
 
 [testenv:coverage]
@@ -48,4 +49,5 @@ commands =
 deps =
     mypy
     python-socks
+    trio
     werkzeug