diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index bf50bd6f5..05947f3a0 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -12,7 +12,7 @@ from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import Headers, HeadersLike +from ..datastructures import HeadersLike from ..exceptions import ( InvalidMessage, InvalidProxyMessage, @@ -23,12 +23,13 @@ ) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import build_authorization_basic, build_host, validate_subprotocols +from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri +from ..uri import WebSocketURI, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection @@ -721,25 +722,6 @@ async def connect_socks_proxy( raise ProxyError("failed to connect to SOCKS proxy") from exc -def prepare_connect_request( - proxy: Proxy, - ws_uri: WebSocketURI, - user_agent_header: str | None = None, -) -> bytes: - host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) - headers = Headers() - headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) - if user_agent_header is not None: - headers["User-Agent"] = user_agent_header - if proxy.username is not None: - assert proxy.password is not None # enforced by parse_proxy() - headers["Proxy-Authorization"] = build_authorization_basic( - proxy.username, proxy.password - ) - # We cannot use the Request class because it supports only GET requests. - return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() - - class HTTPProxyConnection(asyncio.Protocol): def __init__( self, diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index af26d5d7a..f27cb2e7c 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -35,7 +35,7 @@ def __len__(self) -> int: return len(self.queue) def put(self, item: T) -> None: - """Put an item into the queue without waiting.""" + """Put an item into the queue.""" self.queue.append(item) if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_result(None) @@ -81,8 +81,7 @@ class Assembler: """ - # coverage reports incorrectly: "line NN didn't jump to the function exit" - def __init__( # pragma: no cover + def __init__( self, high: int | None = None, low: int | None = None, @@ -155,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data: # until get() fetches a complete message or is canceled. try: - # First frame + # Fetch the first frame. frame = await self.frames.get(not self.closed) self.maybe_resume() assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY @@ -163,7 +162,7 @@ async def get(self, decode: bool | None = None) -> Data: decode = frame.opcode is OP_TEXT frames = [frame] - # Following frames, for fragmented messages + # Fetch subsequent frames for fragmented messages. while not frame.fin: try: frame = await self.frames.get(not self.closed) @@ -230,7 +229,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. - # First frame + # Yield the first frame. try: frame = await self.frames.get(not self.closed) except asyncio.CancelledError: @@ -247,7 +246,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # Convert to bytes when frame.data is a bytearray. yield bytes(frame.data) - # Following frames, for fragmented messages + # Yield subsequent frames for fragmented messages. while not frame.fin: # We cannot handle asyncio.CancelledError because we don't buffer # previous fragments — we're streaming them. Canceling get_iter() @@ -280,22 +279,22 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.high is None: return - # Check for "> high" to support high = 0 + # Check for "> high" to support high = 0. if len(self.frames) > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.low is None: return - # Check for "<= low" to support low = 0 + # Check for "<= low" to support low = 0. if len(self.frames) <= self.low and self.paused: self.paused = False self.resume() diff --git a/src/websockets/proxy.py b/src/websockets/proxy.py new file mode 100644 index 000000000..a343b37bc --- /dev/null +++ b/src/websockets/proxy.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import dataclasses +import urllib.parse +import urllib.request + +from .datastructures import Headers +from .exceptions import InvalidProxy +from .headers import build_authorization_basic, build_host +from .http11 import USER_AGENT +from .uri import DELIMS, WebSocketURI + + +__all__ = ["get_proxy", "parse_proxy", "Proxy"] + + +@dataclasses.dataclass +class Proxy: + """ + Proxy address. + + Attributes: + scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, + ``"https"``, or ``"http"``. + host: Normalized to lower case. + port: Always set even if it's the default. + username: Available when the proxy address contains `User Information`_. + password: Available when the proxy address contains `User Information`_. + + .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 + + """ + + scheme: str + host: str + port: int + username: str | None = None + password: str | None = None + + @property + def user_info(self) -> tuple[str, str] | None: + if self.username is None: + return None + assert self.password is not None + return (self.username, self.password) + + +def parse_proxy(proxy: str) -> Proxy: + """ + Parse and validate a proxy. + + Args: + proxy: proxy. + + Returns: + Parsed proxy. + + Raises: + InvalidProxy: If ``proxy`` isn't a valid proxy. + + """ + parsed = urllib.parse.urlparse(proxy) + if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: + raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") + if parsed.hostname is None: + raise InvalidProxy(proxy, "hostname isn't provided") + if parsed.path not in ["", "/"]: + raise InvalidProxy(proxy, "path is meaningless") + if parsed.query != "": + raise InvalidProxy(proxy, "query is meaningless") + if parsed.fragment != "": + raise InvalidProxy(proxy, "fragment is meaningless") + + scheme = parsed.scheme + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + username = parsed.username + password = parsed.password + # urllib.parse.urlparse accepts URLs with a username but without a + # password. This doesn't make sense for HTTP Basic Auth credentials. + if username is not None and password is None: + raise InvalidProxy(proxy, "username provided without password") + + try: + proxy.encode("ascii") + except UnicodeEncodeError: + # Input contains non-ASCII characters. + # It must be an IRI. Convert it to a URI. + host = host.encode("idna").decode() + if username is not None: + assert password is not None + username = urllib.parse.quote(username, safe=DELIMS) + password = urllib.parse.quote(password, safe=DELIMS) + + return Proxy(scheme, host, port, username, password) + + +def get_proxy(uri: WebSocketURI) -> str | None: + """ + Return the proxy to use for connecting to the given WebSocket URI, if any. + + """ + if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): + return None + + # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if + # available, else favor the proxy for HTTPS connections over the proxy for + # HTTP connections. + + # The priority of a proxy for WebSocket connections is unspecified. We give + # it the highest priority. This makes it easy to configure a specific proxy + # for websockets. + + # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or + # as {"https": "socks5h://host:port"} depending on whether they're declared + # in the operating system or in environment variables. + + proxies = urllib.request.getproxies() + if uri.secure: + schemes = ["wss", "socks", "https"] + else: + schemes = ["ws", "socks", "https", "http"] + + for scheme in schemes: + proxy = proxies.get(scheme) + if proxy is not None: + if scheme == "socks" and proxy.startswith("http://"): + proxy = "socks5h://" + proxy[7:] + return proxy + else: + return None + + +def prepare_connect_request( + proxy: Proxy, + ws_uri: WebSocketURI, + user_agent_header: str | None = USER_AGENT, +) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if user_agent_header is not None: + headers["User-Agent"] = user_agent_header + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 8042a3744..b3fff44ee 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -8,16 +8,17 @@ from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import Headers, HeadersLike +from ..datastructures import HeadersLike from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import build_authorization_basic, build_host, validate_subprotocols +from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request from ..streams import StreamReader from ..typing import BytesLike, LoggerLike, Origin, Subprotocol -from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri +from ..uri import WebSocketURI, parse_uri from .connection import Connection from .utils import Deadline @@ -476,25 +477,6 @@ def connect_socks_proxy( raise ProxyError("failed to connect to SOCKS proxy") from exc -def prepare_connect_request( - proxy: Proxy, - ws_uri: WebSocketURI, - user_agent_header: str | None = None, -) -> bytes: - host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) - headers = Headers() - headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) - if user_agent_header is not None: - headers["User-Agent"] = user_agent_header - if proxy.username is not None: - assert proxy.password is not None # enforced by parse_proxy() - headers["Proxy-Authorization"] = build_authorization_basic( - proxy.username, proxy.password - ) - # We cannot use the Request class because it supports only GET requests. - return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() - - def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: reader = StreamReader() parser = Response.parse( diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index c4d04bc83..d95519f63 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -165,7 +165,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: try: deadline = Deadline(timeout) - # First frame + # Fetch the first frame. frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False)) with self.mutex: self.maybe_resume() @@ -174,7 +174,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: decode = frame.opcode is OP_TEXT frames = [frame] - # Following frames, for fragmented messages + # Fetch subsequent frames for fragmented messages. while not frame.fin: try: frame = self.get_next_frame( @@ -245,7 +245,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: # If get_iter() raises an exception e.g. in decoder.decode(), # get_in_progress remains set and the connection becomes unusable. - # First frame + # Yield the first frame. frame = self.get_next_frame() with self.mutex: self.maybe_resume() @@ -259,7 +259,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: # Convert to bytes when frame.data is a bytearray. yield bytes(frame.data) - # Following frames, for fragmented messages + # Yield subsequent frames for fragmented messages. while not frame.fin: frame = self.get_next_frame() with self.mutex: @@ -300,26 +300,26 @@ def put(self, frame: Frame) -> None: def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.high is None: return assert self.mutex.locked() - # Check for "> high" to support high = 0 + # Check for "> high" to support high = 0. if self.frames.qsize() > self.high and not self.paused: self.paused = True self.pause() def maybe_resume(self) -> None: """Resume the writer if queue is below the low water mark.""" - # Skip if flow control is disabled + # Skip if flow control is disabled. if self.low is None: return assert self.mutex.locked() - # Check for "<= low" to support low = 0 + # Check for "<= low" to support low = 0. if self.frames.qsize() <= self.low and self.paused: self.paused = False self.resume() diff --git a/src/websockets/trio/__init__.py b/src/websockets/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/trio/messages.py b/src/websockets/trio/messages.py new file mode 100644 index 000000000..42423a856 --- /dev/null +++ b/src/websockets/trio/messages.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import codecs +import math +from collections.abc import AsyncIterator +from typing import Any, Callable, Literal, overload + +import trio + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming frames. + self.send_frames: trio.MemorySendChannel[Frame] + self.recv_frames: trio.MemoryReceiveChannel[Frame] + self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf) + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is canceled. + + try: + # Fetch the first frame. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Fetch subsequent frames for fragmented messages. + while not frame.fin: + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + # Put frames already received back into the queue + # so that future calls to get() can return them. + # Bypass the statistics() method for performance. + state = self.send_frames._state + assert not state.receive_tasks, "no task should receive" + assert not state.data, "queue should be empty" + for frame in frames: + self.send_frames.send_nowait(frame) + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + # This converts frame.data to bytes when it's a bytearray. + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is canceled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # Yield the first frame. + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + self.get_in_progress = False + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + # Yield subsequent frames for fragmented messages. + while not frame.fin: + # We cannot handle trio.Cancelled because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise ConcurrencyError. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.send_frames.send_nowait(frame) + self.maybe_pause() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled. + if self.high is None: + return + + # Bypass the statistics() method for performance. + # Check for "> high" to support high = 0. + if len(self.send_frames._state.data) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled. + if self.low is None: + return + + # Bypass the statistics() method for performance. + # Check for "<= low" to support low = 0. + if len(self.send_frames._state.data) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`EOFError`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.send_frames.close() diff --git a/src/websockets/trio/utils.py b/src/websockets/trio/utils.py new file mode 100644 index 000000000..8f3bdd822 --- /dev/null +++ b/src/websockets/trio/utils.py @@ -0,0 +1,42 @@ +import sys + +import trio + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +__all__ = ["race_events"] + + +# Based on https://trio.readthedocs.io/en/stable/reference-core.html#custom-supervisors + + +async def jockey(event: trio.Event, cancel_scope: trio.CancelScope) -> None: + await event.wait() + cancel_scope.cancel() + + +async def race_events(*events: trio.Event) -> None: + """ + Wait for any of the given events to be set. + + Args: + *events: The events to wait for. + + """ + if not events: + raise ValueError("no events provided") + + try: + async with trio.open_nursery() as nursery: + for event in events: + nursery.start_soon(jockey, event, nursery.cancel_scope) + except BaseExceptionGroup as exc: + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise AssertionError( + "race_events should be canceled; please file a bug report" + ) from exc diff --git a/src/websockets/uri.py b/src/websockets/uri.py index b925b99b5..f85e16810 100644 --- a/src/websockets/uri.py +++ b/src/websockets/uri.py @@ -2,9 +2,8 @@ import dataclasses import urllib.parse -import urllib.request -from .exceptions import InvalidProxy, InvalidURI +from .exceptions import InvalidURI __all__ = ["parse_uri", "WebSocketURI"] @@ -106,120 +105,3 @@ def parse_uri(uri: str) -> WebSocketURI: password = urllib.parse.quote(password, safe=DELIMS) return WebSocketURI(secure, host, port, path, query, username, password) - - -@dataclasses.dataclass -class Proxy: - """ - Proxy. - - Attributes: - scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``, - ``"https"``, or ``"http"``. - host: Normalized to lower case. - port: Always set even if it's the default. - username: Available when the proxy address contains `User Information`_. - password: Available when the proxy address contains `User Information`_. - - .. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1 - - """ - - scheme: str - host: str - port: int - username: str | None = None - password: str | None = None - - @property - def user_info(self) -> tuple[str, str] | None: - if self.username is None: - return None - assert self.password is not None - return (self.username, self.password) - - -def parse_proxy(proxy: str) -> Proxy: - """ - Parse and validate a proxy. - - Args: - proxy: proxy. - - Returns: - Parsed proxy. - - Raises: - InvalidProxy: If ``proxy`` isn't a valid proxy. - - """ - parsed = urllib.parse.urlparse(proxy) - if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]: - raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported") - if parsed.hostname is None: - raise InvalidProxy(proxy, "hostname isn't provided") - if parsed.path not in ["", "/"]: - raise InvalidProxy(proxy, "path is meaningless") - if parsed.query != "": - raise InvalidProxy(proxy, "query is meaningless") - if parsed.fragment != "": - raise InvalidProxy(proxy, "fragment is meaningless") - - scheme = parsed.scheme - host = parsed.hostname - port = parsed.port or (443 if parsed.scheme == "https" else 80) - username = parsed.username - password = parsed.password - # urllib.parse.urlparse accepts URLs with a username but without a - # password. This doesn't make sense for HTTP Basic Auth credentials. - if username is not None and password is None: - raise InvalidProxy(proxy, "username provided without password") - - try: - proxy.encode("ascii") - except UnicodeEncodeError: - # Input contains non-ASCII characters. - # It must be an IRI. Convert it to a URI. - host = host.encode("idna").decode() - if username is not None: - assert password is not None - username = urllib.parse.quote(username, safe=DELIMS) - password = urllib.parse.quote(password, safe=DELIMS) - - return Proxy(scheme, host, port, username, password) - - -def get_proxy(uri: WebSocketURI) -> str | None: - """ - Return the proxy to use for connecting to the given WebSocket URI, if any. - - """ - if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"): - return None - - # According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if - # available, else favor the proxy for HTTPS connections over the proxy for - # HTTP connections. - - # The priority of a proxy for WebSocket connections is unspecified. We give - # it the highest priority. This makes it easy to configure a specific proxy - # for websockets. - - # getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or - # as {"https": "socks5h://host:port"} depending on whether they're declared - # in the operating system or in environment variables. - - proxies = urllib.request.getproxies() - if uri.secure: - schemes = ["wss", "socks", "https"] - else: - schemes = ["ws", "socks", "https", "http"] - - for scheme in schemes: - proxy = proxies.get(scheme) - if proxy is not None: - if scheme == "socks" and proxy.startswith("http://"): - proxy = "socks5h://" + proxy[7:] - return proxy - else: - return None diff --git a/tests/__init__.py b/tests/__init__.py index bb1866f2d..83b10efb2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ import logging import os +import tracemalloc format = "%(asctime)s %(levelname)s %(name)s %(message)s" @@ -12,3 +13,7 @@ level = logging.CRITICAL logging.basicConfig(format=format, level=level) + +if bool(os.environ.get("WEBSOCKETS_TRACE")): # pragma: no cover + # Trace allocations to debug resource warnings. + tracemalloc.start() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 6cad971c7..39fc953dc 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -19,9 +19,8 @@ from websockets.protocol import CLIENT, SERVER, Protocol, State from ..protocol import RecordingProtocol -from ..utils import MS +from ..utils import MS, alist from .connection import InterceptingConnection -from .utils import alist # Connection implements symmetrical behavior between clients and servers. diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index a90788d02..c862090a3 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import unittest import unittest.mock @@ -8,7 +9,7 @@ from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame -from .utils import alist +from ..utils import alist class SimpleQueueTests(unittest.IsolatedAsyncioTestCase): @@ -32,7 +33,7 @@ async def test_put_then_get(self): async def test_get_then_put(self): """get returns an item when it is put.""" getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start + await asyncio.sleep(0) # let the event loop start getter_task self.queue.put(42) item = await getter_task self.assertEqual(item, 42) @@ -46,7 +47,7 @@ async def test_reset(self): async def test_abort(self): """abort throws an exception in get.""" getter_task = asyncio.create_task(self.queue.get()) - await asyncio.sleep(0) # let the task start + await asyncio.sleep(0) # let the event loop start getter_task self.queue.abort() with self.assertRaises(EOFError): await getter_task @@ -58,7 +59,7 @@ async def asyncSetUp(self): self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) - # Test get + # Test get. async def test_get_text_message_already_received(self): """get returns a text message that is already received.""" @@ -107,6 +108,7 @@ async def test_get_fragmented_binary_message_already_received(self): async def test_get_fragmented_text_message_not_received_yet(self): """get reassembles a fragmented text message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) @@ -116,6 +118,7 @@ async def test_get_fragmented_text_message_not_received_yet(self): async def test_get_fragmented_binary_message_not_received_yet(self): """get reassembles a fragmented binary message when it is received.""" getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) @@ -126,6 +129,7 @@ async def test_get_fragmented_text_message_being_received(self): """get reassembles a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) message = await getter_task @@ -135,6 +139,7 @@ async def test_get_fragmented_binary_message_being_received(self): """get reassembles a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start getter_task self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) message = await getter_task @@ -161,11 +166,9 @@ async def test_get_resumes_reading(self): # queue is above the low-water mark await self.assembler.get() self.resume.assert_not_called() - # queue is at the low-water mark await self.assembler.get() self.resume.assert_called_once_with() - # queue is below the low-water mark await self.assembler.get() self.resume.assert_called_once_with() @@ -180,7 +183,6 @@ async def test_get_does_not_resume_reading(self): await self.assembler.get() await self.assembler.get() await self.assembler.get() - self.resume.assert_not_called() async def test_cancel_get_before_first_frame(self): @@ -192,7 +194,6 @@ async def test_cancel_get_before_first_frame(self): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - message = await self.assembler.get() self.assertEqual(message, "café") @@ -208,11 +209,10 @@ async def test_cancel_get_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - message = await self.assembler.get() self.assertEqual(message, "café") - # Test get_iter + # Test get_iter. async def test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" @@ -261,42 +261,46 @@ async def test_get_iter_fragmented_binary_message_already_received(self): async def test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = aiter(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(await anext(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(await anext(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(await anext(iterator), "é") + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = aiter(self.assembler.get_iter()) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assertEqual(await anext(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(await anext(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(await anext(iterator), b"a") + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") async def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = aiter(self.assembler.get_iter()) - self.assertEqual(await anext(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(await anext(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(await anext(iterator), "é") + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") async def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = aiter(self.assembler.get_iter()) - self.assertEqual(await anext(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(await anext(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(await anext(iterator), b"a") + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") async def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -321,18 +325,16 @@ async def test_get_iter_resumes_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) - - # queue is above the low-water mark - await anext(iterator) - self.resume.assert_not_called() - - # queue is at the low-water mark - await anext(iterator) - self.resume.assert_called_once_with() - - # queue is below the low-water mark - await anext(iterator) - self.resume.assert_called_once_with() + async with contextlib.aclosing(iterator): + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() async def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" @@ -342,11 +344,11 @@ async def test_get_iter_does_not_resume_reading(self): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = aiter(self.assembler.get_iter()) - await anext(iterator) - await anext(iterator) - await anext(iterator) - - self.resume.assert_not_called() + async with contextlib.aclosing(iterator): + await anext(iterator) + await anext(iterator) + await anext(iterator) + self.resume.assert_not_called() async def test_cancel_get_iter_before_first_frame(self): """get_iter can be canceled safely before reading the first frame.""" @@ -357,7 +359,6 @@ async def test_cancel_get_iter_before_first_frame(self): await getter_task self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - fragments = await alist(self.assembler.get_iter()) self.assertEqual(fragments, ["café"]) @@ -373,11 +374,10 @@ async def test_cancel_get_iter_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - # Test put + # Test put. async def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" @@ -385,11 +385,9 @@ async def test_put_pauses_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() - # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() - # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() @@ -402,10 +400,9 @@ async def test_put_does_not_pause_reading(self): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.pause.assert_not_called() - # Test termination + # Test termination. async def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" @@ -467,7 +464,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self): self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_partially_queued_fragmented_message_after_close(self): - """get raises EOF on a partial fragmented message after close is called.""" + """get raises EOFError on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() @@ -496,41 +493,41 @@ async def test_close_is_idempotent(self): self.assembler.close() self.assembler.close() - # Test (non-)concurrency + # Test (non-)concurrency. async def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" - asyncio.create_task(self.assembler.get()) - await asyncio.sleep(0) + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" - asyncio.create_task(alist(self.assembler.get_iter())) - await asyncio.sleep(0) + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" - asyncio.create_task(self.assembler.get()) - await asyncio.sleep(0) + getter_task = asyncio.create_task(self.assembler.get()) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" - asyncio.create_task(alist(self.assembler.get_iter())) - await asyncio.sleep(0) + getter_task = asyncio.create_task(alist(self.assembler.get_iter())) + await asyncio.sleep(0) # let the event loop start the task with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate + getter_task.cancel() - # Test setting limits + # Test setting limits. async def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" diff --git a/tests/asyncio/test_router.py b/tests/asyncio/test_router.py index 3dd766c96..b746052c1 100644 --- a/tests/asyncio/test_router.py +++ b/tests/asyncio/test_router.py @@ -8,9 +8,8 @@ from websockets.asyncio.router import * from websockets.exceptions import InvalidStatus -from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, temp_unix_socket_path +from ..utils import CLIENT_CONTEXT, SERVER_CONTEXT, alist, temp_unix_socket_path from .server import EvalShellMixin, get_uri, handler -from .utils import alist try: diff --git a/tests/asyncio/utils.py b/tests/asyncio/utils.py deleted file mode 100644 index a611bfc4b..000000000 --- a/tests/asyncio/utils.py +++ /dev/null @@ -1,5 +0,0 @@ -async def alist(async_iterable): - items = [] - async for item in async_iterable: - items.append(item) - return items diff --git a/tests/requirements.txt b/tests/requirements.txt index f375e6f69..77de5350b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,2 +1,3 @@ -python-socks[asyncio] mitmproxy +python-socks[asyncio] +trio diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index e42784094..9bd5119fb 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,3 +1,4 @@ +import contextlib import time import unittest import unittest.mock @@ -16,7 +17,7 @@ def setUp(self): self.resume = unittest.mock.Mock() self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) - # Test get + # Test get. def test_get_text_message_already_received(self): """get returns a text message that is already received.""" @@ -40,7 +41,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - self.assertEqual(message, "café") def test_get_binary_message_not_received_yet(self): @@ -53,7 +53,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assertEqual(message, b"tea") def test_get_fragmented_text_message_already_received(self): @@ -84,7 +83,6 @@ def getter(): self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(message, "café") def test_get_fragmented_binary_message_not_received_yet(self): @@ -99,7 +97,6 @@ def getter(): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(message, b"tea") def test_get_fragmented_text_message_being_received(self): @@ -114,7 +111,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(message, "café") def test_get_fragmented_binary_message_being_received(self): @@ -129,7 +125,6 @@ def getter(): with self.run_in_thread(getter): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(message, b"tea") def test_get_encoded_text_message(self): @@ -153,11 +148,9 @@ def test_get_resumes_reading(self): # queue is above the low-water mark self.assembler.get() self.resume.assert_not_called() - # queue is at the low-water mark self.assembler.get() self.resume.assert_called_once_with() - # queue is below the low-water mark self.assembler.get() self.resume.assert_called_once_with() @@ -172,7 +165,6 @@ def test_get_does_not_resume_reading(self): self.assembler.get() self.assembler.get() self.assembler.get() - self.resume.assert_not_called() def test_get_timeout_before_first_frame(self): @@ -181,7 +173,6 @@ def test_get_timeout_before_first_frame(self): self.assembler.get(timeout=MS) self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - message = self.assembler.get() self.assertEqual(message, "café") @@ -194,7 +185,6 @@ def test_get_timeout_after_first_frame(self): self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) - message = self.assembler.get() self.assertEqual(message, "café") @@ -224,7 +214,7 @@ def test_get_timeout_0_fragmented_message_partially_received(self): with self.assertRaises(TimeoutError): self.assembler.get(timeout=0) - # Test get_iter + # Test get_iter. def test_get_iter_text_message_already_received(self): """get_iter yields a text message that is already received.""" @@ -240,30 +230,26 @@ def test_get_iter_binary_message_already_received(self): def test_get_iter_text_message_not_received_yet(self): """get_iter yields a text message when it is received.""" - fragments = [] + fragments = None def getter(): nonlocal fragments - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + fragments = list(self.assembler.get_iter()) with self.run_in_thread(getter): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - self.assertEqual(fragments, ["café"]) def test_get_iter_binary_message_not_received_yet(self): """get_iter yields a binary message when it is received.""" - fragments = [] + fragments = None def getter(): nonlocal fragments - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + fragments = list(self.assembler.get_iter()) with self.run_in_thread(getter): self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assertEqual(fragments, [b"tea"]) def test_get_iter_fragmented_text_message_already_received(self): @@ -285,42 +271,46 @@ def test_get_iter_fragmented_binary_message_already_received(self): def test_get_iter_fragmented_text_message_not_received_yet(self): """get_iter yields a fragmented text message when it is received.""" iterator = self.assembler.get_iter() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(next(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(next(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(next(iterator), "é") + with contextlib.closing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" iterator = self.assembler.get_iter() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assertEqual(next(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(next(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(next(iterator), b"a") + with contextlib.closing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) iterator = self.assembler.get_iter() - self.assertEqual(next(iterator), "ca") - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assertEqual(next(iterator), "f") - self.assembler.put(Frame(OP_CONT, b"\xa9")) - self.assertEqual(next(iterator), "é") + with contextlib.closing(iterator): + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) iterator = self.assembler.get_iter() - self.assertEqual(next(iterator), b"t") - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assertEqual(next(iterator), b"e") - self.assembler.put(Frame(OP_CONT, b"a")) - self.assertEqual(next(iterator), b"a") + with contextlib.closing(iterator): + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -345,18 +335,16 @@ def test_get_iter_resumes_reading(self): self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() - - # queue is above the low-water mark - next(iterator) - self.resume.assert_not_called() - - # queue is at the low-water mark - next(iterator) - self.resume.assert_called_once_with() - - # queue is below the low-water mark - next(iterator) - self.resume.assert_called_once_with() + with contextlib.closing(iterator): + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" @@ -366,13 +354,13 @@ def test_get_iter_does_not_resume_reading(self): self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) iterator = self.assembler.get_iter() - next(iterator) - next(iterator) - next(iterator) + with contextlib.closing(iterator): + next(iterator) + next(iterator) + next(iterator) + self.resume.assert_not_called() - self.resume.assert_not_called() - - # Test put + # Test put. def test_put_pauses_reading(self): """put pauses reading when queue goes above the high-water mark.""" @@ -380,11 +368,9 @@ def test_put_pauses_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.pause.assert_not_called() - # queue is at the high-water mark self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.pause.assert_called_once_with() - # queue is above the high-water mark self.assembler.put(Frame(OP_CONT, b"a")) self.pause.assert_called_once_with() @@ -397,10 +383,9 @@ def test_put_does_not_pause_reading(self): self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) - self.pause.assert_not_called() - # Test termination + # Test termination. def test_get_fails_when_interrupted_by_close(self): """get raises EOFError when close is called.""" @@ -472,7 +457,7 @@ def test_get_iter_queued_fragmented_message_after_close(self): self.assertEqual(fragments, [b"t", b"e", b"a"]) def test_get_partially_queued_fragmented_message_after_close(self): - """get raises EOF on a partial fragmented message after close is called.""" + """get raises EOFError on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() @@ -501,7 +486,6 @@ def test_close_resumes_reading(self): self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) self.assembler.put(Frame(OP_TEXT, b"water")) - # queue is at the high-water mark assert self.assembler.paused @@ -513,7 +497,7 @@ def test_close_is_idempotent(self): self.assembler.close() self.assembler.close() - # Test (non-)concurrency + # Test (non-)concurrency. def test_get_fails_when_get_is_running(self): """get cannot be called concurrently.""" @@ -543,7 +527,7 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - # Test setting limits + # Test setting limits. def test_set_high_water_mark(self): """high sets the high-water and low-water marks.""" diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 000000000..e0d12898e --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,233 @@ +import os +import unittest +from unittest.mock import patch + +from websockets.exceptions import InvalidProxy +from websockets.http11 import USER_AGENT +from websockets.proxy import * +from websockets.proxy import prepare_connect_request +from websockets.uri import parse_uri + + +VALID_PROXIES = [ + ( + "http://proxy:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "https://proxy:8080", + Proxy("https", "proxy", 8080, None, None), + ), + ( + "http://proxy", + Proxy("http", "proxy", 80, None, None), + ), + ( + "http://proxy:8080/", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://PROXY:8080", + Proxy("http", "proxy", 8080, None, None), + ), + ( + "http://user:pass@proxy:8080", + Proxy("http", "proxy", 8080, "user", "pass"), + ), + ( + "http://høst:8080/", + Proxy("http", "xn--hst-0na", 8080, None, None), + ), + ( + "http://üser:påss@høst:8080", + Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), + ), +] + +INVALID_PROXIES = [ + "ws://proxy:8080", + "wss://proxy:8080", + "http://proxy:8080/path", + "http://proxy:8080/?query", + "http://proxy:8080/#fragment", + "http://user@proxy", + "http:///", +] + +PROXIES_WITH_USER_INFO = [ + ("http://proxy", None), + ("http://user:pass@proxy", ("user", "pass")), + ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), +] + +PROXY_ENVS = [ + ( + {"ws_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"ws_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "ws://example.com/", + None, + ), + ( + {"wss_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"http_proxy": "http://proxy:8080"}, + "wss://example.com/", + None, + ), + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + "http://proxy:8080", + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + "http://proxy:8080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy1:8080", + ), + ( + {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "ws://example.com/", + "http://proxy2:8080", + ), + ( + {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, + "wss://example.com/", + "http://proxy2:8080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "ws://example.com/", + "socks5h://proxy:1080", + ), + ( + {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, + "wss://example.com/", + "socks5h://proxy:1080", + ), + ( + {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, + "ws://example.local/", + None, + ), +] + +CONNECT_REQUESTS = [ + ( + {"https_proxy": "http://proxy:8080"}, + "ws://example.com/", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n\r\n" + ), + ), + ( + {"https_proxy": "http://proxy:8080"}, + "wss://example.com/", + ( + b"CONNECT example.com:443 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n\r\n" + ), + ), + ( + {"https_proxy": "http://hello:iloveyou@proxy:8080"}, + "ws://example.com/", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: " + USER_AGENT.encode() + b"\r\n" + b"Proxy-Authorization: Basic aGVsbG86aWxvdmV5b3U=\r\n\r\n" + ), + ), +] + +CONNECT_REQUESTS_WITH_USER_AGENT = [ + ( + "Smith", + ( + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n" + b"User-Agent: Smith\r\n\r\n" + ), + ), + ( + None, + b"CONNECT example.com:80 HTTP/1.1\r\nHost: example.com\r\n\r\n", + ), +] + + +class ProxyTests(unittest.TestCase): + def test_parse_valid_proxies(self): + for proxy, parsed in VALID_PROXIES: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy), parsed) + + def test_parse_invalid_proxies(self): + for proxy in INVALID_PROXIES: + with self.subTest(proxy=proxy): + with self.assertRaises(InvalidProxy): + parse_proxy(proxy) + + def test_parse_proxy_user_info(self): + for proxy, user_info in PROXIES_WITH_USER_INFO: + with self.subTest(proxy=proxy): + self.assertEqual(parse_proxy(proxy).user_info, user_info) + + def test_get_proxy(self): + for environ, uri, proxy in PROXY_ENVS: + with patch.dict(os.environ, environ): + with self.subTest(environ=environ, uri=uri): + self.assertEqual(get_proxy(parse_uri(uri)), proxy) + + def test_prepare_connect_request(self): + for environ, uri, request in CONNECT_REQUESTS: + with patch.dict(os.environ, environ): + with self.subTest(environ=environ, uri=uri): + uri = parse_uri(uri) + proxy = parse_proxy(get_proxy(uri)) + self.assertEqual(prepare_connect_request(proxy, uri), request) + + def test_prepare_connect_request_with_user_agent(self): + for user_agent_header, request in CONNECT_REQUESTS_WITH_USER_AGENT: + with self.subTest(user_agent_header=user_agent_header): + uri = parse_uri("ws://example.com") + proxy = parse_proxy("http://proxy:8080") + self.assertEqual( + prepare_connect_request(proxy, uri, user_agent_header), + request, + ) diff --git a/tests/test_uri.py b/tests/test_uri.py index 3ccf21158..057a17291 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,10 +1,7 @@ -import os import unittest -from unittest.mock import patch -from websockets.exceptions import InvalidProxy, InvalidURI +from websockets.exceptions import InvalidURI from websockets.uri import * -from websockets.uri import Proxy, get_proxy, parse_proxy VALID_URIS = [ @@ -75,145 +72,6 @@ ("ws://üser:påss@høst/", ("%C3%BCser", "p%C3%A5ss")), ] -VALID_PROXIES = [ - ( - "http://proxy:8080", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "https://proxy:8080", - Proxy("https", "proxy", 8080, None, None), - ), - ( - "http://proxy", - Proxy("http", "proxy", 80, None, None), - ), - ( - "http://proxy:8080/", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "http://PROXY:8080", - Proxy("http", "proxy", 8080, None, None), - ), - ( - "http://user:pass@proxy:8080", - Proxy("http", "proxy", 8080, "user", "pass"), - ), - ( - "http://høst:8080/", - Proxy("http", "xn--hst-0na", 8080, None, None), - ), - ( - "http://üser:påss@høst:8080", - Proxy("http", "xn--hst-0na", 8080, "%C3%BCser", "p%C3%A5ss"), - ), -] - -INVALID_PROXIES = [ - "ws://proxy:8080", - "wss://proxy:8080", - "http://proxy:8080/path", - "http://proxy:8080/?query", - "http://proxy:8080/#fragment", - "http://user@proxy", - "http:///", -] - -PROXIES_WITH_USER_INFO = [ - ("http://proxy", None), - ("http://user:pass@proxy", ("user", "pass")), - ("http://üser:påss@høst", ("%C3%BCser", "p%C3%A5ss")), -] - -PROXY_ENVS = [ - ( - {"ws_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"ws_proxy": "http://proxy:8080"}, - "wss://example.com/", - None, - ), - ( - {"wss_proxy": "http://proxy:8080"}, - "ws://example.com/", - None, - ), - ( - {"wss_proxy": "http://proxy:8080"}, - "wss://example.com/", - "http://proxy:8080", - ), - ( - {"http_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"http_proxy": "http://proxy:8080"}, - "wss://example.com/", - None, - ), - ( - {"https_proxy": "http://proxy:8080"}, - "ws://example.com/", - "http://proxy:8080", - ), - ( - {"https_proxy": "http://proxy:8080"}, - "wss://example.com/", - "http://proxy:8080", - ), - ( - {"socks_proxy": "http://proxy:1080"}, - "ws://example.com/", - "socks5h://proxy:1080", - ), - ( - {"socks_proxy": "http://proxy:1080"}, - "wss://example.com/", - "socks5h://proxy:1080", - ), - ( - {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, - "ws://example.com/", - "http://proxy1:8080", - ), - ( - {"ws_proxy": "http://proxy1:8080", "wss_proxy": "http://proxy2:8080"}, - "wss://example.com/", - "http://proxy2:8080", - ), - ( - {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, - "ws://example.com/", - "http://proxy2:8080", - ), - ( - {"http_proxy": "http://proxy1:8080", "https_proxy": "http://proxy2:8080"}, - "wss://example.com/", - "http://proxy2:8080", - ), - ( - {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, - "ws://example.com/", - "socks5h://proxy:1080", - ), - ( - {"https_proxy": "http://proxy:8080", "socks_proxy": "http://proxy:1080"}, - "wss://example.com/", - "socks5h://proxy:1080", - ), - ( - {"socks_proxy": "http://proxy:1080", "no_proxy": ".local"}, - "ws://example.local/", - None, - ), -] - class URITests(unittest.TestCase): def test_parse_valid_uris(self): @@ -236,25 +94,3 @@ def test_parse_user_info(self): for uri, user_info in URIS_WITH_USER_INFO: with self.subTest(uri=uri): self.assertEqual(parse_uri(uri).user_info, user_info) - - def test_parse_valid_proxies(self): - for proxy, parsed in VALID_PROXIES: - with self.subTest(proxy=proxy): - self.assertEqual(parse_proxy(proxy), parsed) - - def test_parse_invalid_proxies(self): - for proxy in INVALID_PROXIES: - with self.subTest(proxy=proxy): - with self.assertRaises(InvalidProxy): - parse_proxy(proxy) - - def test_parse_proxy_user_info(self): - for proxy, user_info in PROXIES_WITH_USER_INFO: - with self.subTest(proxy=proxy): - self.assertEqual(parse_proxy(proxy).user_info, user_info) - - def test_get_proxy(self): - for environ, uri, proxy in PROXY_ENVS: - with patch.dict(os.environ, environ): - with self.subTest(environ=environ, uri=uri): - self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/trio/__init__.py b/tests/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/trio/test_messages.py b/tests/trio/test_messages.py new file mode 100644 index 000000000..a5f6923b4 --- /dev/null +++ b/tests/trio/test_messages.py @@ -0,0 +1,571 @@ +import contextlib +import unittest +import unittest.mock + +import trio.testing + +from websockets.asyncio.compatibility import aiter, anext +from websockets.exceptions import ConcurrencyError +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from websockets.trio.messages import * + +from ..utils import alist +from .utils import IsolatedTrioTestCase + + +class AssemblerTests(IsolatedTrioTestCase): + def setUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) + + # Test get. + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + async def getter(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + self.resume.assert_not_called() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter. + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = None + + async def getter(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = None + + async def getter(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + self.nursery.start_soon(getter) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + async with contextlib.aclosing(iterator): + await anext(iterator) + await anext(iterator) + await anext(iterator) + self.resume.assert_not_called() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get_iter cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + with self.assertRaises(ConcurrencyError): + await alist(self.assembler.get_iter()) + + # Test put. + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_not_called() + + # Test termination. + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + self.nursery.start_soon(closer) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + self.nursery.start_soon(closer) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOFError on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency. + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + nursery.start_soon(self.assembler.get) + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + nursery.start_soon(self.assembler.get) + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(self.assembler.get) + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently.""" + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + nursery.start_soon(lambda: alist(self.assembler.get_iter())) + + # Test setting limits. + + async def test_set_high_water_mark(self): + """high sets the high-water and low-water marks.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) + + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) diff --git a/tests/trio/test_utils.py b/tests/trio/test_utils.py new file mode 100644 index 000000000..1ecdd80f1 --- /dev/null +++ b/tests/trio/test_utils.py @@ -0,0 +1,40 @@ +import trio.testing + +from websockets.trio.utils import * + +from .utils import IsolatedTrioTestCase + + +class UtilsTests(IsolatedTrioTestCase): + async def test_race_events(self): + event1 = trio.Event() + event2 = trio.Event() + done = trio.Event() + + async def waiter(): + await race_events(event1, event2) + done.set() + + async with trio.open_nursery() as nursery: + nursery.start_soon(waiter) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(done.is_set()) + + event1.set() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(done.is_set()) + + async def test_race_events_cancelled(self): + event1 = trio.Event() + event2 = trio.Event() + + async def waiter(): + with trio.move_on_after(0): + await race_events(event1, event2) + + async with trio.open_nursery() as nursery: + nursery.start_soon(waiter) + + async def test_race_events_no_events(self): + with self.assertRaises(ValueError): + await race_events() diff --git a/tests/trio/utils.py b/tests/trio/utils.py new file mode 100644 index 000000000..1e0676d56 --- /dev/null +++ b/tests/trio/utils.py @@ -0,0 +1,62 @@ +import functools +import inspect +import sys +import unittest + +import trio.testing + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +class IsolatedTrioTestCase(unittest.TestCase): + """ + Wrap test coroutines with :func:`trio.testing.trio_test` automatically. + + Create a nursery for each test, available in the :attr:`nursery` attribute. + + :meth:`asyncSetUp` and :meth:`asyncTearDown` are supported, similar to + :class:`unittest.IsolatedAsyncioTestCase`, but ``addAsyncCleanup`` isn't. + + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + for name in unittest.defaultTestLoader.getTestCaseNames(cls): + test = getattr(cls, name) + if getattr(test, "converted_to_trio", False): # pragma: no cover + return + assert inspect.iscoroutinefunction(test) + setattr(cls, name, cls.convert_to_trio(test)) + + @staticmethod + def convert_to_trio(test): + @trio.testing.trio_test + @functools.wraps(test) + async def new_test(self, *args, **kwargs): + try: + # Provide a nursery so it's easy to start tasks. + async with trio.open_nursery() as self.nursery: + await self.asyncSetUp() + try: + return await test(self, *args, **kwargs) + finally: + await self.asyncTearDown() + except BaseExceptionGroup as exc: # pragma: no cover + # Unwrap exceptions like unittest.SkipTest. Multiple exceptions + # could occur is a test fails with multiple errors; this is OK; + # raise the original exception group in that case. + try: + trio._util.raise_single_exception_from_group(exc) + except trio._util.MultipleExceptionError: # pragma: no cover + raise exc + + new_test.converted_to_trio = True + return new_test + + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass diff --git a/tests/utils.py b/tests/utils.py index bd3bb0ed9..4db014337 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -112,6 +112,13 @@ def assertDeprecationWarning(self, message): self.assertEqual(str(warning.message), message) +async def alist(async_iterable): + items = [] + async for item in async_iterable: + items.append(item) + return items + + @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tox.ini b/tox.ini index ce4572e59..dce6698c3 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,py314,coverage,maxi_cov: mitmproxy py311,py312,py313,py314,coverage,maxi_cov: python-socks[asyncio] + trio werkzeug [testenv:coverage] @@ -48,4 +49,5 @@ commands = deps = mypy python-socks + trio werkzeug