Skip to content
1 change: 1 addition & 0 deletions newsfragments/279.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``trio.open_unix_listener``, ``trio.serve_unix``, and ``trio.UnixSocketListener`` to support ``SOCK_STREAM`` `Unix domain sockets <https://en.wikipedia.org/wiki/Unix_domain_socket>`__
4 changes: 4 additions & 0 deletions src/trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
serve_tcp as serve_tcp,
)
from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream
from ._highlevel_open_unix_listeners import (
open_unix_listener as open_unix_listener,
serve_unix as serve_unix,
)
from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket
from ._highlevel_serve_listeners import serve_listeners as serve_listeners
from ._highlevel_socket import (
Expand Down
133 changes: 133 additions & 0 deletions src/trio/_highlevel_open_unix_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import trio
import trio.socket as tsocket
from trio import TaskStatus

from ._highlevel_open_tcp_listeners import _compute_backlog

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable


try:
from trio.socket import AF_UNIX

HAS_UNIX = True
except ImportError:
HAS_UNIX = False


async def open_unix_listener(
path: str | bytes | os.PathLike[str] | os.PathLike[bytes],
*,
mode: int | None = None,
backlog: int | None = None,
) -> trio.SocketListener:
"""Create :class:`SocketListener` objects to listen for connections.
Opens a connection to the specified
`Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.

You must have read/write permission on the specified file to connect.

Args:

path (str): Filename of UNIX socket to create and listen on.
Absolute or relative paths may be used.

mode (int or None): The socket file permissions.
UNIX permissions are usually specified in octal numbers. If
you leave this as ``None``, Trio will not change the mode from
the operating system's default.

backlog (int or None): The listen backlog to use. If you leave this as
``None`` then Trio will pick a good default. (Currently:
whatever your system has configured as the maximum backlog.)

Returns:
:class:`UnixSocketListener`

Raises:
:class:`ValueError` If invalid arguments.
:class:`RuntimeError`: If AF_UNIX sockets are not supported.
:class:`FileNotFoundError`: If folder socket file is to be created in does not exist.
"""
if not HAS_UNIX:
raise RuntimeError("Unix sockets are not supported on this platform")

computed_backlog = _compute_backlog(backlog)

fspath = await trio.Path(os.fsdecode(path)).absolute()

folder = fspath.parent
if not await folder.exists():
raise FileNotFoundError(f"Socket folder does not exist: {folder!r}")

str_path = str(fspath)

# much more simplified logic vs tcp sockets - one socket family and only one
# possible location to connect to
sock = tsocket.socket(AF_UNIX, tsocket.SOCK_STREAM)
try:
await sock.bind(str_path)

if mode is not None:
await fspath.chmod(mode)

sock.listen(computed_backlog)

return trio.SocketListener(sock)
except BaseException:
sock.close()
if os.path.exists(str_path):
os.unlink(str_path)
raise


async def serve_unix(
handler: Callable[[trio.SocketStream], Awaitable[object]],
path: str | bytes | os.PathLike[str] | os.PathLike[bytes],
*,
backlog: int | None = None,
handler_nursery: trio.Nursery | None = None,
task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED,
) -> None:
"""Listen for incoming UNIX connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around :func:`open_unix_listener` and
:func:`serve_listeners` – see them for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it – so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
path: The socket file name.
Passed to :func:`open_unix_listener`.
backlog: The listen backlog, or None to have a good default picked.
Passed to :func:`open_tcp_listener`.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
Raises:
RuntimeError: If AF_UNIX sockets are not supported.
"""
if not HAS_UNIX:
raise RuntimeError("Unix sockets are not supported on this platform")

listener = await open_unix_listener(path, backlog=backlog)
await trio.serve_listeners(
handler,
[listener],
handler_nursery=handler_nursery,
task_status=task_status,
)
43 changes: 36 additions & 7 deletions src/trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import errno
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, overload
from os import stat, unlink
from os.path import exists
from stat import S_ISSOCK
from typing import TYPE_CHECKING, Final, overload

import trio

Expand All @@ -16,7 +19,7 @@

from typing_extensions import Buffer

from ._socket import SocketType
from ._socket import AddressFormat, SocketType

# XX TODO: this number was picked arbitrarily. We should do experiments to
# tune it. (Or make it dynamic -- one idea is to start small and increase it
Expand All @@ -31,6 +34,8 @@
errno.ENOTSOCK,
}

HAS_UNIX: Final = hasattr(tsocket, "AF_UNIX")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is cleaner than the definition in _highlevel_open_unix_listeners.py



@contextmanager
def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]:
Expand Down Expand Up @@ -68,13 +73,15 @@ class SocketStream(HalfCloseableStream):

"""

__slots__ = ("_send_conflict_detector", "socket")

def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
raise ValueError("SocketStream requires a SOCK_STREAM socket")

self.socket = socket
self.socket: SocketType = socket
self._send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SocketStream",
)
Expand Down Expand Up @@ -356,15 +363,22 @@ class SocketListener(Listener[SocketStream]):
and be listening.

Note that the :class:`SocketListener` "takes ownership" of the given
socket; closing the :class:`SocketListener` will also close the socket.
socket; closing the :class:`SocketListener` will also close the
socket, and if it's a Unix socket, it will also unlink the leftover
socket file that the Unix socket is bound to.

.. attribute:: socket

The Trio socket object that this stream wraps.

"""

def __init__(self, socket: SocketType) -> None:
__slots__ = ("socket",)

def __init__(
self,
socket: SocketType,
) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand All @@ -378,7 +392,7 @@ def __init__(self, socket: SocketType) -> None:
if not listening:
raise ValueError("SocketListener requires a listening socket")

self.socket = socket
self.socket: SocketType = socket

async def accept(self) -> SocketStream:
"""Accept an incoming connection.
Expand Down Expand Up @@ -409,6 +423,21 @@ async def accept(self) -> SocketStream:
return SocketStream(sock)

async def aclose(self) -> None:
"""Close this listener and its underlying socket."""
"""Close this listener, its underlying socket, and for Unix sockets unlink the socket file."""
is_unix_socket = self.socket.family == getattr(tsocket, "AF_UNIX", None)

path: AddressFormat | None = None
if is_unix_socket:
# If unix socket, need to get path before we close socket
# or OS errors
path = self.socket.getsockname()
self.socket.close()
# If unix socket, clean up socket file that gets left behind.
if (
is_unix_socket
and path is not None
and exists(path)
and S_ISSOCK(stat(path).st_mode)
):
unlink(path)
await trio.lowlevel.checkpoint()
8 changes: 6 additions & 2 deletions src/trio/_tests/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,6 @@ def lookup_symbol(symbol: str) -> dict[str, str]:
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
trio.SSLListener: {"transport_listener"},
trio.SSLStream: {"transport_stream"},
trio.SocketListener: {"socket"},
trio.SocketStream: {"socket"},
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
trio.testing.MemorySendStream: {
"close_hook",
Expand Down Expand Up @@ -525,6 +523,12 @@ def lookup_symbol(symbol: str) -> dict[str, str]:

print(f"\n{tool} can't see the following symbols in {module_name}:")
pprint(errors)
print(
f"""
If there are extra attributes listed, try checking to make sure this test
isn't ignoring them. If there are missing attributes, try looking for why
{tool} isn't seeing them compared to `inspect.getmembers`."""
)
assert not errors


Expand Down
Loading
Loading