Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,12 @@ inside a single process, and for that you can use

.. autofunction:: open_memory_channel(max_buffer_size)

Assigning the send and receive channels to separate variables usually
produces the most readable code. However, in situations where the pair
is preserved -- such as a collection of memory channels -- prefer named tuple
access (``pair.send_channel``, ``pair.receive_channel``) over indexed access
(``pair[0]``, ``pair[1]``).

Comment on lines +1223 to +1228
Copy link
Member Author

@belm0 belm0 Oct 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please note this point

It's why I didn't go crazy converting all the docs and tests to named tuple access-- tuple destructuring is more readable for local code dealing with a single memory channel.

However the named tuple still wins when you're dealing with many channels and keep the pairs intact.

.. note:: If you've used the :mod:`threading` or :mod:`asyncio`
modules, you may be familiar with :class:`queue.Queue` or
:class:`asyncio.Queue`. In Trio, :func:`open_memory_channel` is
Expand Down
5 changes: 5 additions & 0 deletions newsfragments/1771.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
open_memory_channel() now returns a named tuple with attributes ``send_channel``
and ``receive_channel``. This can be used to avoid indexed access of the
channel halves in some scenarios such as a collection of channels. (Note: when
dealing with a single memory channel, assigning the send and receive halves
to separate variables via destructuring is still considered more readable.)
150 changes: 120 additions & 30 deletions src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,36 @@

from collections import OrderedDict, deque
from math import inf
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Generic,
Tuple, # only needed for typechecking on <3.9
TypeVar,
)

import attrs
from outcome import Error, Value

import trio

from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType
from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
from ._util import NoPublicConstructor, final, generic_function

if TYPE_CHECKING:
from collections.abc import Iterable
from types import TracebackType

from typing_extensions import Self


T = TypeVar("T")


def _open_memory_channel(
max_buffer_size: int | float, # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
) -> MemoryChannelPair[T]:
"""Open a channel for passing objects between tasks within a process.

Memory channels are lightweight, cheap to allocate, and entirely
Expand Down Expand Up @@ -53,9 +59,8 @@ def _open_memory_channel(
see :ref:`channel-buffering` for more details. If in doubt, use 0.

Returns:
A pair ``(send_channel, receive_channel)``. If you have
trouble remembering which order these go in, remember: data
flows from left → right.
A named tuple ``(send_channel, receive_channel)``. The tuple ordering is
intended to match the image of data flowing from left → right.

In addition to the standard channel methods, all memory channel objects
provide a ``statistics()`` method, which returns an object with the
Expand All @@ -82,33 +87,12 @@ def _open_memory_channel(
if max_buffer_size < 0:
raise ValueError("max_buffer_size must be >= 0")
state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size)
return (
return MemoryChannelPair(
MemorySendChannel[T]._create(state),
MemoryReceiveChannel[T]._create(state),
)


# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
# written as a class so you can say open_memory_channel[int](5)
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple["MemorySendChannel[T]", "MemoryReceiveChannel[T]"]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls, max_buffer_size: int | float # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int | float): # noqa: PYI041
...

else:
# apply the generic_function decorator to make open_memory_channel indexable
# so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
open_memory_channel = generic_function(_open_memory_channel)


@attrs.frozen
class MemoryChannelStats:
current_buffer_used: int
Expand Down Expand Up @@ -144,9 +128,12 @@ def statistics(self) -> MemoryChannelStats:

@final
@attrs.define(eq=False, repr=False, slots=False)
class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor):
class MemorySendChannel(
SendChannel[SendType],
Generic[SendType],
metaclass=NoPublicConstructor,
):
_state: MemoryChannelState[SendType]
_closed: bool = False
# This is just the tasks waiting on *this* object. As compared to
# self._state.send_tasks, which includes tasks from this object and
# all clones.
Expand Down Expand Up @@ -287,7 +274,11 @@ async def aclose(self) -> None:

@final
@attrs.define(eq=False, repr=False, slots=False)
class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor):
class MemoryReceiveChannel(
ReceiveChannel[ReceiveType],
Generic[ReceiveType],
metaclass=NoPublicConstructor,
):
_state: MemoryChannelState[ReceiveType]
_closed: bool = False
_tasks: set[trio._core._run.Task] = attrs.Factory(set)
Expand Down Expand Up @@ -431,3 +422,102 @@ def close(self) -> None:
async def aclose(self) -> None:
self.close()
await trio.lowlevel.checkpoint()


# typing.NamedTuple does not support generics before Python 3.11, so we
# hand-roll the named-tuple machinery on top of a plain tuple subclass.
class MemoryChannelPair(
    Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]],
    Generic[T],
):
    """Named tuple of send/receive memory channels.

    Behaves exactly like the 2-tuple ``(send_channel, receive_channel)``
    while also exposing the two halves as the named attributes
    ``send_channel`` and ``receive_channel``.
    """

    __slots__ = ()
    # Mirror the collections.namedtuple API (_fields, _make, _replace,
    # _asdict, __getnewargs__) so this is a drop-in named tuple.
    _fields = ("send_channel", "receive_channel")

    if TYPE_CHECKING:
        # Spelled as properties so type checkers see the per-half types.

        @property
        def send_channel(self) -> MemorySendChannel[T]:
            """Returns the sending channel half."""
            return self[0]

        @property
        def receive_channel(self) -> MemoryReceiveChannel[T]:
            """Returns the receiving channel half."""
            return self[1]

    else:  # More efficient: itemgetter avoids a Python-level function call
        send_channel = property(itemgetter(0), doc="Returns the sending channel half.")
        receive_channel = property(
            itemgetter(1), doc="Returns the receiving channel half."
        )

    def __new__(
        cls,
        send_channel: MemorySendChannel[T],
        receive_channel: MemoryReceiveChannel[T],
    ) -> Self:
        """Create new instance of MemoryChannelPair(send_channel, receive_channel)"""
        return tuple.__new__(cls, (send_channel, receive_channel))  # type: ignore[type-var]

    @classmethod
    def _make(
        cls,
        iterable: Iterable[MemorySendChannel[T] | MemoryReceiveChannel[T]],
    ) -> Self:
        """Make a new MemoryChannelPair object from a sequence or iterable

        Raises:
          TypeError: if the two halves are supplied in (receive, send) order.
        """
        send, rec = iterable
        # Guard against callers swapping the halves, which a plain tuple
        # would silently accept.
        if isinstance(send, MemoryReceiveChannel) or isinstance(rec, MemorySendChannel):
            raise TypeError("Channel order passed incorrectly.")
        return tuple.__new__(cls, (send, rec))  # type: ignore[type-var]

    def _replace(
        self,
        *,
        send_channel: MemorySendChannel[T] | None = None,
        receive_channel: MemoryReceiveChannel[T] | None = None,
    ) -> MemoryChannelPair[T]:
        """Return a new MemoryChannelPair object replacing specified fields with new values"""
        if send_channel is None:
            send_channel = self.send_channel
        if receive_channel is None:
            receive_channel = self.receive_channel
        # The ignore comment must sit on the line where mypy reports the
        # error (the call's opening line), not on the closing paren of a
        # multi-line call, or it is silently ineffective.
        return tuple.__new__(  # type: ignore[type-var]
            MemoryChannelPair,
            (send_channel, receive_channel),
        )

    def __repr__(self) -> str:
        """Return a nicely formatted representation string"""
        return f"{self.__class__.__name__}(send_channel={self[0]!r}, receive_channel={self[1]!r})"

    def _asdict(
        self,
    ) -> OrderedDict[str, MemorySendChannel[T] | MemoryReceiveChannel[T]]:
        """Return a new OrderedDict which maps field names to their values."""
        return OrderedDict(zip(self._fields, self))

    def __getnewargs__(self) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
        """Return self as a plain tuple. Used by copy and pickle."""
        return (self[0], self[1])


# This workaround requires python3.9+, once older python versions are not supported
# or there's a better way of achieving type-checking on a generic factory function,
# it could replace the normal function header
if TYPE_CHECKING:
    # written as a class so that you can say open_memory_channel[int](5)
    # Need to use Tuple instead of tuple due to CI check running on 3.8
    # Subclassing MemoryChannelPair[T] means type checkers treat the result
    # as the named tuple, so ``pair.send_channel`` etc. type-check.
    class open_memory_channel(MemoryChannelPair[T]):
        def __new__(  # type: ignore[misc] # "must return a subtype"
            cls, max_buffer_size: int | float  # noqa: PYI041
        ) -> MemoryChannelPair[T]:
            return _open_memory_channel(max_buffer_size)

        # __init__ exists only so type checkers accept the call signature;
        # __new__ above returns a plain MemoryChannelPair, so this body
        # never runs.
        def __init__(self, max_buffer_size: int | float):  # noqa: PYI041
            ...

else:
    # apply the generic_function decorator to make open_memory_channel indexable
    # so it's valid to say e.g. ``open_memory_channel[bytes](5)`` at runtime
    open_memory_channel = generic_function(_open_memory_channel)
5 changes: 5 additions & 0 deletions src/trio/_tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,8 @@ async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()


def test_named_tuple() -> None:
    """Named-tuple attribute access must match positional tuple contents."""
    pair = open_memory_channel(0)
    # NB: must compare a parenthesized tuple against the pair. The form
    # ``assert a, b == pair`` parses as ``assert expr, message`` and only
    # checks that ``a`` is truthy.
    assert (pair.send_channel, pair.receive_channel) == pair
4 changes: 2 additions & 2 deletions src/trio/_tests/test_highlevel_serve_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
async def connect(self) -> StapledMemoryStream:
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
await self.queued_streams.send_channel.send(server)
return client

async def accept(self) -> StapledMemoryStream:
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
await self.accept_hook()
stream = await self.queued_streams[1].receive()
stream = await self.queued_streams.receive_channel.receive()
self.accepted_streams.append(stream)
return stream

Expand Down