Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 87 additions & 53 deletions libp2p/stream_muxer/mplex/mplex_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,49 @@
)


class ReadWriteLock:
def __init__(self):
self._readers = 0
self._lock = trio.Lock() # Protects _readers
self._write_lock = trio.Lock()

async def acquire_read(self):
async with self._lock:
self._readers += 1
if self._readers == 1:
await self._write_lock.acquire()

async def release_read(self):
async with self._lock:
self._readers -= 1
if self._readers == 0:
self._write_lock.release()

async def acquire_write(self):
await self._write_lock.acquire()

def release_write(self):
self._write_lock.release()


class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
"""

name: str
stream_id: StreamID
# NOTE: All methods used here are part of `Mplex` which is a derived
# NOTE: All methods used here are part of Mplex which is a derived
# class of IMuxedConn. Ignoring this type assignment should not pose
# any risk.
muxed_conn: "Mplex" # type: ignore[assignment]
read_deadline: int | None
write_deadline: int | None

# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock

# NOTE: `dataIn` is size of 8 in Go implementation.
# NOTE: dataIn is size of 8 in Go implementation.
incoming_data_channel: "trio.MemoryReceiveChannel[bytes]"

event_local_closed: trio.Event
Expand Down Expand Up @@ -80,6 +105,7 @@ def __init__(
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()

Expand All @@ -106,70 +132,78 @@ def _read_return_when_blocked(self) -> bytearray:

async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes, if
there are not enough bytes in the Mplex buffer. If `n is None`, read
Read up to n bytes. Read possibly returns fewer than n bytes, if
there are not enough bytes in the Mplex buffer. If n is None, read
until EOF.

:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
await self.rw_lock.acquire_read()
try:
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read n must be non-negative or "
f"None to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know receive will be blocked here. Wait for data here with
# receive and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably incoming_data_channel is closed in reset when we are
# waiting for receive.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"incoming_data_channel is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
finally:
await self.rw_lock.release_read()

async def write(self, data: bytes) -> None:
"""
Write to stream.

:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
await self.rw_lock.acquire_write()
try:
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
finally:
self.rw_lock.release_write()

async def close(self) -> None:
"""
Expand All @@ -185,7 +219,7 @@ async def close(self) -> None:
flag = (
HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver
)
# TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown.
# TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown.
await self.muxed_conn.send_message(flag, None, self.stream_id)

_is_remote_closed: bool
Expand Down