Skip to content

Commit e011ef6

Browse files
committed
Add trio connection.
1 parent 899df77 commit e011ef6

File tree

11 files changed

+2767
-222
lines changed

11 files changed

+2767
-222
lines changed

docs/topics/keepalive.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames::
136136
Alternatively, you can measure the latency at any time by calling
137137
:attr:`~asyncio.connection.Connection.ping` and awaiting its result::
138138

139-
pong_waiter = await websocket.ping()
140-
latency = await pong_waiter
139+
pong_received = await websocket.ping()
140+
latency = await pong_received
141141

142142
Latency between a client and a server may increase for two reasons:
143143

src/websockets/asyncio/connection.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,10 @@ def __init__(
101101
self.close_deadline: float | None = None
102102

103103
# Protect sending fragmented messages.
104-
self.fragmented_send_waiter: asyncio.Future[None] | None = None
104+
self.send_in_progress: asyncio.Future[None] | None = None
105105

106106
# Mapping of ping IDs to pong waiters, in chronological order.
107-
self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {}
107+
self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {}
108108

109109
self.latency: float = 0
110110
"""
@@ -468,8 +468,8 @@ async def send(
468468
"""
469469
# While sending a fragmented message, prevent sending other messages
470470
# until all fragments are sent.
471-
while self.fragmented_send_waiter is not None:
472-
await asyncio.shield(self.fragmented_send_waiter)
471+
while self.send_in_progress is not None:
472+
await asyncio.shield(self.send_in_progress)
473473

474474
# Unfragmented message -- this case must be handled first because
475475
# strings and bytes-like objects are iterable.
@@ -502,8 +502,8 @@ async def send(
502502
except StopIteration:
503503
return
504504

505-
assert self.fragmented_send_waiter is None
506-
self.fragmented_send_waiter = self.loop.create_future()
505+
assert self.send_in_progress is None
506+
self.send_in_progress = self.loop.create_future()
507507
try:
508508
# First fragment.
509509
if isinstance(chunk, str):
@@ -549,8 +549,8 @@ async def send(
549549
raise
550550

551551
finally:
552-
self.fragmented_send_waiter.set_result(None)
553-
self.fragmented_send_waiter = None
552+
self.send_in_progress.set_result(None)
553+
self.send_in_progress = None
554554

555555
# Fragmented message -- async iterator.
556556

@@ -561,8 +561,8 @@ async def send(
561561
except StopAsyncIteration:
562562
return
563563

564-
assert self.fragmented_send_waiter is None
565-
self.fragmented_send_waiter = self.loop.create_future()
564+
assert self.send_in_progress is None
565+
self.send_in_progress = self.loop.create_future()
566566
try:
567567
# First fragment.
568568
if isinstance(chunk, str):
@@ -610,8 +610,8 @@ async def send(
610610
raise
611611

612612
finally:
613-
self.fragmented_send_waiter.set_result(None)
614-
self.fragmented_send_waiter = None
613+
self.send_in_progress.set_result(None)
614+
self.send_in_progress = None
615615

616616
else:
617617
raise TypeError("data must be str, bytes, iterable, or async iterable")
@@ -639,7 +639,7 @@ async def close(
639639
# The context manager takes care of waiting for the TCP connection
640640
# to terminate after calling a method that sends a close frame.
641641
async with self.send_context():
642-
if self.fragmented_send_waiter is not None:
642+
if self.send_in_progress is not None:
643643
self.protocol.fail(
644644
CloseCode.INTERNAL_ERROR,
645645
"close during fragmented message",
@@ -681,9 +681,9 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
681681
682682
::
683683
684-
pong_waiter = await ws.ping()
684+
pong_received = await ws.ping()
685685
# only if you want to wait for the corresponding pong
686-
latency = await pong_waiter
686+
latency = await pong_received
687687
688688
Raises:
689689
ConnectionClosed: When the connection is closed.
@@ -700,19 +700,19 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
700700

701701
async with self.send_context():
702702
# Protect against duplicates if a payload is explicitly set.
703-
if data in self.pong_waiters:
703+
if data in self.pending_pings:
704704
raise ConcurrencyError("already waiting for a pong with the same data")
705705

706706
# Generate a unique random payload otherwise.
707-
while data is None or data in self.pong_waiters:
707+
while data is None or data in self.pending_pings:
708708
data = struct.pack("!I", random.getrandbits(32))
709709

710-
pong_waiter = self.loop.create_future()
710+
pong_received = self.loop.create_future()
711711
# The event loop's default clock is time.monotonic(). Its resolution
712712
# is a bit low on Windows (~16ms). This is improved in Python 3.13.
713-
self.pong_waiters[data] = (pong_waiter, self.loop.time())
713+
self.pending_pings[data] = (pong_received, self.loop.time())
714714
self.protocol.send_ping(data)
715-
return pong_waiter
715+
return pong_received
716716

717717
async def pong(self, data: DataLike = b"") -> None:
718718
"""
@@ -761,7 +761,7 @@ def acknowledge_pings(self, data: bytes) -> None:
761761
762762
"""
763763
# Ignore unsolicited pong.
764-
if data not in self.pong_waiters:
764+
if data not in self.pending_pings:
765765
return
766766

767767
pong_timestamp = self.loop.time()
@@ -770,20 +770,20 @@ def acknowledge_pings(self, data: bytes) -> None:
770770
# Acknowledge all previous pings too in that case.
771771
ping_id = None
772772
ping_ids = []
773-
for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items():
773+
for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items():
774774
ping_ids.append(ping_id)
775775
latency = pong_timestamp - ping_timestamp
776-
if not pong_waiter.done():
777-
pong_waiter.set_result(latency)
776+
if not pong_received.done():
777+
pong_received.set_result(latency)
778778
if ping_id == data:
779779
self.latency = latency
780780
break
781781
else:
782782
raise AssertionError("solicited pong not found in pings")
783783

784-
# Remove acknowledged pings from self.pong_waiters.
784+
# Remove acknowledged pings from self.pending_pings.
785785
for ping_id in ping_ids:
786-
del self.pong_waiters[ping_id]
786+
del self.pending_pings[ping_id]
787787

788788
def abort_pings(self) -> None:
789789
"""
@@ -795,16 +795,16 @@ def abort_pings(self) -> None:
795795
assert self.protocol.state is CLOSED
796796
exc = self.protocol.close_exc
797797

798-
for pong_waiter, _ping_timestamp in self.pong_waiters.values():
799-
if not pong_waiter.done():
800-
pong_waiter.set_exception(exc)
798+
for pong_received, _ping_timestamp in self.pending_pings.values():
799+
if not pong_received.done():
800+
pong_received.set_exception(exc)
801801
# If the exception is never retrieved, it will be logged when ping
802802
# is garbage-collected. This is confusing for users.
803803
# Given that ping is done (with an exception), canceling it does
804804
# nothing, but it prevents logging the exception.
805-
pong_waiter.cancel()
805+
pong_received.cancel()
806806

807-
self.pong_waiters.clear()
807+
self.pending_pings.clear()
808808

809809
async def keepalive(self) -> None:
810810
"""
@@ -825,7 +825,7 @@ async def keepalive(self) -> None:
825825
# connection to be closed before raising ConnectionClosed.
826826
# However, connection_lost() cancels keepalive_task before
827827
# it gets a chance to resume excuting.
828-
pong_waiter = await self.ping()
828+
pong_received = await self.ping()
829829
if self.debug:
830830
self.logger.debug("% sent keepalive ping")
831831

@@ -834,9 +834,9 @@ async def keepalive(self) -> None:
834834
async with asyncio_timeout(self.ping_timeout):
835835
# connection_lost cancels keepalive immediately
836836
# after setting a ConnectionClosed exception on
837-
# pong_waiter. A CancelledError is raised here,
837+
# pong_received. A CancelledError is raised here,
838838
# not a ConnectionClosed exception.
839-
latency = await pong_waiter
839+
latency = await pong_received
840840
self.logger.debug("% received keepalive pong")
841841
except asyncio.TimeoutError:
842842
if self.debug:
@@ -1205,7 +1205,7 @@ def broadcast(
12051205
if connection.protocol.state is not OPEN:
12061206
continue
12071207

1208-
if connection.fragmented_send_waiter is not None:
1208+
if connection.send_in_progress is not None:
12091209
if raise_exceptions:
12101210
exception = ConcurrencyError("sending a fragmented message")
12111211
exceptions.append(exception)

src/websockets/protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def close_exc(self) -> ConnectionClosed:
254254

255255
# Public methods for receiving data.
256256

257-
def receive_data(self, data: bytes) -> None:
257+
def receive_data(self, data: bytes | bytearray) -> None:
258258
"""
259259
Receive data from the network.
260260

src/websockets/streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def at_eof(self) -> Generator[None, None, bool]:
112112
# tell if until either feed_data() or feed_eof() is called.
113113
yield
114114

115-
def feed_data(self, data: bytes) -> None:
115+
def feed_data(self, data: bytes | bytearray) -> None:
116116
"""
117117
Write data to the stream.
118118

src/websockets/sync/connection.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
self.send_in_progress = False
105105

106106
# Mapping of ping IDs to pong waiters, in chronological order.
107-
self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {}
107+
self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {}
108108

109109
self.latency: float = 0
110110
"""
@@ -633,8 +633,9 @@ def ping(
633633
634634
::
635635
636-
pong_event = ws.ping()
637-
pong_event.wait() # only if you want to wait for the pong
636+
pong_received = ws.ping()
637+
# only if you want to wait for the corresponding pong
638+
pong_received.wait()
638639
639640
Raises:
640641
ConnectionClosed: When the connection is closed.
@@ -651,17 +652,17 @@ def ping(
651652

652653
with self.send_context():
653654
# Protect against duplicates if a payload is explicitly set.
654-
if data in self.pong_waiters:
655+
if data in self.pending_pings:
655656
raise ConcurrencyError("already waiting for a pong with the same data")
656657

657658
# Generate a unique random payload otherwise.
658-
while data is None or data in self.pong_waiters:
659+
while data is None or data in self.pending_pings:
659660
data = struct.pack("!I", random.getrandbits(32))
660661

661-
pong_waiter = threading.Event()
662-
self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close)
662+
pong_received = threading.Event()
663+
self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close)
663664
self.protocol.send_ping(data)
664-
return pong_waiter
665+
return pong_received
665666

666667
def pong(self, data: DataLike = b"") -> None:
667668
"""
@@ -711,7 +712,7 @@ def acknowledge_pings(self, data: bytes) -> None:
711712
"""
712713
with self.protocol_mutex:
713714
# Ignore unsolicited pong.
714-
if data not in self.pong_waiters:
715+
if data not in self.pending_pings:
715716
return
716717

717718
pong_timestamp = time.monotonic()
@@ -721,21 +722,21 @@ def acknowledge_pings(self, data: bytes) -> None:
721722
ping_id = None
722723
ping_ids = []
723724
for ping_id, (
724-
pong_waiter,
725+
pong_received,
725726
ping_timestamp,
726727
_ack_on_close,
727-
) in self.pong_waiters.items():
728+
) in self.pending_pings.items():
728729
ping_ids.append(ping_id)
729-
pong_waiter.set()
730+
pong_received.set()
730731
if ping_id == data:
731732
self.latency = pong_timestamp - ping_timestamp
732733
break
733734
else:
734735
raise AssertionError("solicited pong not found in pings")
735736

736-
# Remove acknowledged pings from self.pong_waiters.
737+
# Remove acknowledged pings from self.pending_pings.
737738
for ping_id in ping_ids:
738-
del self.pong_waiters[ping_id]
739+
del self.pending_pings[ping_id]
739740

740741
def acknowledge_pending_pings(self) -> None:
741742
"""
@@ -744,11 +745,11 @@ def acknowledge_pending_pings(self) -> None:
744745
"""
745746
assert self.protocol.state is CLOSED
746747

747-
for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values():
748+
for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values():
748749
if ack_on_close:
749-
pong_waiter.set()
750+
pong_received.set()
750751

751-
self.pong_waiters.clear()
752+
self.pending_pings.clear()
752753

753754
def keepalive(self) -> None:
754755
"""
@@ -766,15 +767,14 @@ def keepalive(self) -> None:
766767
break
767768

768769
try:
769-
pong_waiter = self.ping(ack_on_close=True)
770+
pong_received = self.ping(ack_on_close=True)
770771
except ConnectionClosed:
771772
break
772773
if self.debug:
773774
self.logger.debug("% sent keepalive ping")
774775

775776
if self.ping_timeout is not None:
776-
#
777-
if pong_waiter.wait(self.ping_timeout):
777+
if pong_received.wait(self.ping_timeout):
778778
if self.debug:
779779
self.logger.debug("% received keepalive pong")
780780
else:
@@ -808,7 +808,7 @@ def recv_events(self) -> None:
808808
809809
Run this method in a thread as long as the connection is alive.
810810
811-
``recv_events()`` exits immediately when the ``self.socket`` is closed.
811+
``recv_events()`` exits immediately when ``self.socket`` is closed.
812812
813813
"""
814814
try:
@@ -983,6 +983,7 @@ def send_context(
983983
# Minor layering violation: we assume that the connection
984984
# will be closing soon if it isn't in the expected state.
985985
wait_for_close = True
986+
# TODO: calculate close deadline if not set?
986987
raise_close_exc = True
987988

988989
# To avoid a deadlock, release the connection lock by exiting the

0 commit comments

Comments
 (0)