Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Callable, Literal, cast

from ..client import ClientProtocol, backoff
from ..datastructures import Headers, HeadersLike
from ..datastructures import HeadersLike
from ..exceptions import (
InvalidMessage,
InvalidProxyMessage,
Expand All @@ -23,12 +23,13 @@
)
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import build_authorization_basic, build_host, validate_subprotocols
from ..headers import validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, Event
from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request
from ..streams import StreamReader
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
from ..uri import WebSocketURI, parse_uri
from .compatibility import TimeoutError, asyncio_timeout
from .connection import Connection

Expand Down Expand Up @@ -721,25 +722,6 @@ async def connect_socks_proxy(
raise ProxyError("failed to connect to SOCKS proxy") from exc


def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
proxy.username, proxy.password
)
# We cannot use the Request class because it supports only GET requests.
return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()


class HTTPProxyConnection(asyncio.Protocol):
def __init__(
self,
Expand Down
21 changes: 10 additions & 11 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __len__(self) -> int:
return len(self.queue)

def put(self, item: T) -> None:
"""Put an item into the queue without waiting."""
"""Put an item into the queue."""
self.queue.append(item)
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_result(None)
Expand Down Expand Up @@ -81,8 +81,7 @@ class Assembler:

"""

# coverage reports incorrectly: "line NN didn't jump to the function exit"
def __init__( # pragma: no cover
def __init__(
self,
high: int | None = None,
low: int | None = None,
Expand Down Expand Up @@ -155,15 +154,15 @@ async def get(self, decode: bool | None = None) -> Data:
# until get() fetches a complete message or is canceled.

try:
# First frame
# Fetch the first frame.
frame = await self.frames.get(not self.closed)
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
# Fetch subsequent frames for fragmented messages.
while not frame.fin:
try:
frame = await self.frames.get(not self.closed)
Expand Down Expand Up @@ -230,7 +229,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.

# First frame
# Yield the first frame.
try:
frame = await self.frames.get(not self.closed)
except asyncio.CancelledError:
Expand All @@ -247,7 +246,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
# Convert to bytes when frame.data is a bytearray.
yield bytes(frame.data)

# Following frames, for fragmented messages
# Yield subsequent frames for fragmented messages.
while not frame.fin:
# We cannot handle asyncio.CancelledError because we don't buffer
# previous fragments — we're streaming them. Canceling get_iter()
Expand Down Expand Up @@ -280,22 +279,22 @@ def put(self, frame: Frame) -> None:

def maybe_pause(self) -> None:
"""Pause the writer if queue is above the high water mark."""
# Skip if flow control is disabled
# Skip if flow control is disabled.
if self.high is None:
return

# Check for "> high" to support high = 0
# Check for "> high" to support high = 0.
if len(self.frames) > self.high and not self.paused:
self.paused = True
self.pause()

def maybe_resume(self) -> None:
"""Resume the writer if queue is below the low water mark."""
# Skip if flow control is disabled
# Skip if flow control is disabled.
if self.low is None:
return

# Check for "<= low" to support low = 0
# Check for "<= low" to support low = 0.
if len(self.frames) <= self.low and self.paused:
self.paused = False
self.resume()
Expand Down
150 changes: 150 additions & 0 deletions src/websockets/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import annotations

import dataclasses
import urllib.parse
import urllib.request

from .datastructures import Headers
from .exceptions import InvalidProxy
from .headers import build_authorization_basic, build_host
from .http11 import USER_AGENT
from .uri import DELIMS, WebSocketURI


__all__ = ["get_proxy", "parse_proxy", "Proxy"]


@dataclasses.dataclass
class Proxy:
"""
Proxy address.

Attributes:
scheme: ``"socks5h"``, ``"socks5"``, ``"socks4a"``, ``"socks4"``,
``"https"``, or ``"http"``.
host: Normalized to lower case.
port: Always set even if it's the default.
username: Available when the proxy address contains `User Information`_.
password: Available when the proxy address contains `User Information`_.

.. _User Information: https://datatracker.ietf.org/doc/html/rfc3986#section-3.2.1

"""

scheme: str
host: str
port: int
username: str | None = None
password: str | None = None

@property
def user_info(self) -> tuple[str, str] | None:
if self.username is None:
return None
assert self.password is not None
return (self.username, self.password)


def parse_proxy(proxy: str) -> Proxy:
"""
Parse and validate a proxy.

Args:
proxy: proxy.

Returns:
Parsed proxy.

Raises:
InvalidProxy: If ``proxy`` isn't a valid proxy.

"""
parsed = urllib.parse.urlparse(proxy)
if parsed.scheme not in ["socks5h", "socks5", "socks4a", "socks4", "https", "http"]:
raise InvalidProxy(proxy, f"scheme {parsed.scheme} isn't supported")
if parsed.hostname is None:
raise InvalidProxy(proxy, "hostname isn't provided")
if parsed.path not in ["", "/"]:
raise InvalidProxy(proxy, "path is meaningless")
if parsed.query != "":
raise InvalidProxy(proxy, "query is meaningless")
if parsed.fragment != "":
raise InvalidProxy(proxy, "fragment is meaningless")

scheme = parsed.scheme
host = parsed.hostname
port = parsed.port or (443 if parsed.scheme == "https" else 80)
username = parsed.username
password = parsed.password
# urllib.parse.urlparse accepts URLs with a username but without a
# password. This doesn't make sense for HTTP Basic Auth credentials.
if username is not None and password is None:
raise InvalidProxy(proxy, "username provided without password")

try:
proxy.encode("ascii")
except UnicodeEncodeError:
# Input contains non-ASCII characters.
# It must be an IRI. Convert it to a URI.
host = host.encode("idna").decode()
if username is not None:
assert password is not None
username = urllib.parse.quote(username, safe=DELIMS)
password = urllib.parse.quote(password, safe=DELIMS)

return Proxy(scheme, host, port, username, password)


def get_proxy(uri: WebSocketURI) -> str | None:
"""
Return the proxy to use for connecting to the given WebSocket URI, if any.

"""
if urllib.request.proxy_bypass(f"{uri.host}:{uri.port}"):
return None

# According to the _Proxy Usage_ section of RFC 6455, use a SOCKS5 proxy if
# available, else favor the proxy for HTTPS connections over the proxy for
# HTTP connections.

# The priority of a proxy for WebSocket connections is unspecified. We give
# it the highest priority. This makes it easy to configure a specific proxy
# for websockets.

# getproxies() may return SOCKS proxies as {"socks": "http://host:port"} or
# as {"https": "socks5h://host:port"} depending on whether they're declared
# in the operating system or in environment variables.

proxies = urllib.request.getproxies()
if uri.secure:
schemes = ["wss", "socks", "https"]
else:
schemes = ["ws", "socks", "https", "http"]

for scheme in schemes:
proxy = proxies.get(scheme)
if proxy is not None:
if scheme == "socks" and proxy.startswith("http://"):
proxy = "socks5h://" + proxy[7:]
return proxy
else:
return None


def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = USER_AGENT,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
proxy.username, proxy.password
)
# We cannot use the Request class because it supports only GET requests.
return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()
26 changes: 4 additions & 22 deletions src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from typing import Any, Callable, Literal, TypeVar, cast

from ..client import ClientProtocol
from ..datastructures import Headers, HeadersLike
from ..datastructures import HeadersLike
from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import build_authorization_basic, build_host, validate_subprotocols
from ..headers import validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, Event
from ..proxy import Proxy, get_proxy, parse_proxy, prepare_connect_request
from ..streams import StreamReader
from ..typing import BytesLike, LoggerLike, Origin, Subprotocol
from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
from ..uri import WebSocketURI, parse_uri
from .connection import Connection
from .utils import Deadline

Expand Down Expand Up @@ -476,25 +477,6 @@ def connect_socks_proxy(
raise ProxyError("failed to connect to SOCKS proxy") from exc


def prepare_connect_request(
proxy: Proxy,
ws_uri: WebSocketURI,
user_agent_header: str | None = None,
) -> bytes:
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
headers = Headers()
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
if user_agent_header is not None:
headers["User-Agent"] = user_agent_header
if proxy.username is not None:
assert proxy.password is not None # enforced by parse_proxy()
headers["Proxy-Authorization"] = build_authorization_basic(
proxy.username, proxy.password
)
# We cannot use the Request class because it supports only GET requests.
return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize()


def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response:
reader = StreamReader()
parser = Response.parse(
Expand Down
16 changes: 8 additions & 8 deletions src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
try:
deadline = Deadline(timeout)

# First frame
# Fetch the first frame.
frame = self.get_next_frame(deadline.timeout(raise_if_elapsed=False))
with self.mutex:
self.maybe_resume()
Expand All @@ -174,7 +174,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
# Fetch subsequent frames for fragmented messages.
while not frame.fin:
try:
frame = self.get_next_frame(
Expand Down Expand Up @@ -245,7 +245,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.

# First frame
# Yield the first frame.
frame = self.get_next_frame()
with self.mutex:
self.maybe_resume()
Expand All @@ -259,7 +259,7 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
# Convert to bytes when frame.data is a bytearray.
yield bytes(frame.data)

# Following frames, for fragmented messages
# Yield subsequent frames for fragmented messages.
while not frame.fin:
frame = self.get_next_frame()
with self.mutex:
Expand Down Expand Up @@ -300,26 +300,26 @@ def put(self, frame: Frame) -> None:

def maybe_pause(self) -> None:
"""Pause the writer if queue is above the high water mark."""
# Skip if flow control is disabled
# Skip if flow control is disabled.
if self.high is None:
return

assert self.mutex.locked()

# Check for "> high" to support high = 0
# Check for "> high" to support high = 0.
if self.frames.qsize() > self.high and not self.paused:
self.paused = True
self.pause()

def maybe_resume(self) -> None:
"""Resume the writer if queue is below the low water mark."""
# Skip if flow control is disabled
# Skip if flow control is disabled.
if self.low is None:
return

assert self.mutex.locked()

# Check for "<= low" to support low = 0
# Check for "<= low" to support low = 0.
if self.frames.qsize() <= self.low and self.paused:
self.paused = False
self.resume()
Expand Down
Empty file added src/websockets/trio/__init__.py
Empty file.
Loading