Skip to content

Commit b0d9376

Browse files
committed
Add trio connection.
Uniformize other implementations.
1 parent 2e87c47 commit b0d9376

File tree

9 files changed

+3015
-459
lines changed

9 files changed

+3015
-459
lines changed

docs/topics/keepalive.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames::
136136
Alternatively, you can measure the latency at any time by calling
137137
:attr:`~asyncio.connection.Connection.ping` and awaiting its result::
138138

139-
pong_waiter = await websocket.ping()
140-
latency = await pong_waiter
139+
pong_received = await websocket.ping()
140+
latency = await pong_received
141141

142142
Latency between a client and a server may increase for two reasons:
143143

src/websockets/asyncio/connection.py

Lines changed: 59 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,10 @@ def __init__(
6363
self.ping_interval = ping_interval
6464
self.ping_timeout = ping_timeout
6565
self.close_timeout = close_timeout
66-
self.max_queue: tuple[int | None, int | None]
6766
if isinstance(max_queue, int) or max_queue is None:
68-
self.max_queue = (max_queue, None)
67+
self.max_queue_high, self.max_queue_low = max_queue, None
6968
else:
70-
self.max_queue = max_queue
69+
self.max_queue_high, self.max_queue_low = max_queue
7170
if isinstance(write_limit, int):
7271
write_limit = (write_limit, None)
7372
self.write_limit = write_limit
@@ -101,12 +100,12 @@ def __init__(
101100
self.close_deadline: float | None = None
102101

103102
# Protect sending fragmented messages.
104-
self.fragmented_send_waiter: asyncio.Future[None] | None = None
103+
self.send_in_progress: asyncio.Future[None] | None = None
105104

106105
# Mapping of ping IDs to pong waiters, in chronological order.
107-
self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {}
106+
self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {}
108107

109-
self.latency: float = 0
108+
self.latency: float = 0.0
110109
"""
111110
Latency of the connection, in seconds.
112111
@@ -417,16 +416,16 @@ async def send(
417416
418417
You may override this behavior with the ``text`` argument:
419418
420-
* Set ``text=True`` to send a bytestring or bytes-like object
421-
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
419+
* Set ``text=True`` to send an UTF-8 bytestring or bytes-like object
420+
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a
422421
Text_ frame. This improves performance when the message is already
423422
UTF-8 encoded, for example if the message contains JSON and you're
424423
using a JSON library that produces a bytestring.
425424
* Set ``text=False`` to send a string (:class:`str`) in a Binary_
426425
frame. This may be useful for servers that expect binary frames
427426
instead of text frames.
428427
429-
:meth:`send` also accepts an iterable or an asynchronous iterable of
428+
:meth:`send` also accepts an iterable or asynchronous iterable of
430429
strings, bytestrings, or bytes-like objects to enable fragmentation_.
431430
Each item is treated as a message fragment and sent in its own frame.
432431
All items must be of the same type, or else :meth:`send` will raise a
@@ -468,8 +467,8 @@ async def send(
468467
"""
469468
# While sending a fragmented message, prevent sending other messages
470469
# until all fragments are sent.
471-
while self.fragmented_send_waiter is not None:
472-
await asyncio.shield(self.fragmented_send_waiter)
470+
while self.send_in_progress is not None:
471+
await asyncio.shield(self.send_in_progress)
473472

474473
# Unfragmented message -- this case must be handled first because
475474
# strings and bytes-like objects are iterable.
@@ -502,8 +501,8 @@ async def send(
502501
except StopIteration:
503502
return
504503

505-
assert self.fragmented_send_waiter is None
506-
self.fragmented_send_waiter = self.loop.create_future()
504+
assert self.send_in_progress is None
505+
self.send_in_progress = self.loop.create_future()
507506
try:
508507
# First fragment.
509508
if isinstance(chunk, str):
@@ -549,8 +548,8 @@ async def send(
549548
raise
550549

551550
finally:
552-
self.fragmented_send_waiter.set_result(None)
553-
self.fragmented_send_waiter = None
551+
self.send_in_progress.set_result(None)
552+
self.send_in_progress = None
554553

555554
# Fragmented message -- async iterator.
556555

@@ -561,8 +560,8 @@ async def send(
561560
except StopAsyncIteration:
562561
return
563562

564-
assert self.fragmented_send_waiter is None
565-
self.fragmented_send_waiter = self.loop.create_future()
563+
assert self.send_in_progress is None
564+
self.send_in_progress = self.loop.create_future()
566565
try:
567566
# First fragment.
568567
if isinstance(chunk, str):
@@ -610,8 +609,8 @@ async def send(
610609
raise
611610

612611
finally:
613-
self.fragmented_send_waiter.set_result(None)
614-
self.fragmented_send_waiter = None
612+
self.send_in_progress.set_result(None)
613+
self.send_in_progress = None
615614

616615
else:
617616
raise TypeError("data must be str, bytes, iterable, or async iterable")
@@ -639,7 +638,7 @@ async def close(
639638
# The context manager takes care of waiting for the TCP connection
640639
# to terminate after calling a method that sends a close frame.
641640
async with self.send_context():
642-
if self.fragmented_send_waiter is not None:
641+
if self.send_in_progress is not None:
643642
self.protocol.fail(
644643
CloseCode.INTERNAL_ERROR,
645644
"close during fragmented message",
@@ -681,9 +680,9 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
681680
682681
::
683682
684-
pong_waiter = await ws.ping()
683+
pong_received = await ws.ping()
685684
# only if you want to wait for the corresponding pong
686-
latency = await pong_waiter
685+
latency = await pong_received
687686
688687
Raises:
689688
ConnectionClosed: When the connection is closed.
@@ -700,19 +699,20 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
700699

701700
async with self.send_context():
702701
# Protect against duplicates if a payload is explicitly set.
703-
if data in self.pong_waiters:
702+
if data in self.pending_pings:
704703
raise ConcurrencyError("already waiting for a pong with the same data")
705704

706705
# Generate a unique random payload otherwise.
707-
while data is None or data in self.pong_waiters:
706+
while data is None or data in self.pending_pings:
708707
data = struct.pack("!I", random.getrandbits(32))
709708

710-
pong_waiter = self.loop.create_future()
709+
pong_received = self.loop.create_future()
710+
ping_timestamp = self.loop.time()
711711
# The event loop's default clock is time.monotonic(). Its resolution
712712
# is a bit low on Windows (~16ms). This is improved in Python 3.13.
713-
self.pong_waiters[data] = (pong_waiter, self.loop.time())
713+
self.pending_pings[data] = (pong_received, ping_timestamp)
714714
self.protocol.send_ping(data)
715-
return pong_waiter
715+
return pong_received
716716

717717
async def pong(self, data: DataLike = b"") -> None:
718718
"""
@@ -761,7 +761,7 @@ def acknowledge_pings(self, data: bytes) -> None:
761761
762762
"""
763763
# Ignore unsolicited pong.
764-
if data not in self.pong_waiters:
764+
if data not in self.pending_pings:
765765
return
766766

767767
pong_timestamp = self.loop.time()
@@ -770,41 +770,39 @@ def acknowledge_pings(self, data: bytes) -> None:
770770
# Acknowledge all previous pings too in that case.
771771
ping_id = None
772772
ping_ids = []
773-
for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items():
773+
for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items():
774774
ping_ids.append(ping_id)
775775
latency = pong_timestamp - ping_timestamp
776-
if not pong_waiter.done():
777-
pong_waiter.set_result(latency)
776+
if not pong_received.done():
777+
pong_received.set_result(latency)
778778
if ping_id == data:
779779
self.latency = latency
780780
break
781781
else:
782782
raise AssertionError("solicited pong not found in pings")
783783

784-
# Remove acknowledged pings from self.pong_waiters.
784+
# Remove acknowledged pings from self.pending_pings.
785785
for ping_id in ping_ids:
786-
del self.pong_waiters[ping_id]
786+
del self.pending_pings[ping_id]
787787

788-
def abort_pings(self) -> None:
788+
def terminate_pending_pings(self) -> None:
789789
"""
790-
Raise ConnectionClosed in pending pings.
791-
792-
They'll never receive a pong once the connection is closed.
790+
Raise ConnectionClosed in pending pings when the connection is closed.
793791
794792
"""
795793
assert self.protocol.state is CLOSED
796794
exc = self.protocol.close_exc
797795

798-
for pong_waiter, _ping_timestamp in self.pong_waiters.values():
799-
if not pong_waiter.done():
800-
pong_waiter.set_exception(exc)
796+
for pong_received, _ping_timestamp in self.pending_pings.values():
797+
if not pong_received.done():
798+
pong_received.set_exception(exc)
801799
# If the exception is never retrieved, it will be logged when ping
802800
# is garbage-collected. This is confusing for users.
803801
# Given that ping is done (with an exception), canceling it does
804802
# nothing, but it prevents logging the exception.
805-
pong_waiter.cancel()
803+
pong_received.cancel()
806804

807-
self.pong_waiters.clear()
805+
self.pending_pings.clear()
808806

809807
async def keepalive(self) -> None:
810808
"""
@@ -825,7 +823,7 @@ async def keepalive(self) -> None:
825823
# connection to be closed before raising ConnectionClosed.
826824
# However, connection_lost() cancels keepalive_task before
827825
# it gets a chance to resume excuting.
828-
pong_waiter = await self.ping()
826+
pong_received = await self.ping()
829827
if self.debug:
830828
self.logger.debug("% sent keepalive ping")
831829

@@ -834,10 +832,11 @@ async def keepalive(self) -> None:
834832
async with asyncio_timeout(self.ping_timeout):
835833
# connection_lost cancels keepalive immediately
836834
# after setting a ConnectionClosed exception on
837-
# pong_waiter. A CancelledError is raised here,
835+
# pong_received. A CancelledError is raised here,
838836
# not a ConnectionClosed exception.
839-
latency = await pong_waiter
840-
self.logger.debug("% received keepalive pong")
837+
latency = await pong_received
838+
if self.debug:
839+
self.logger.debug("% received keepalive pong")
841840
except asyncio.TimeoutError:
842841
if self.debug:
843842
self.logger.debug("- timed out waiting for keepalive pong")
@@ -908,14 +907,13 @@ async def send_context(
908907
# Check if the connection is expected to close soon.
909908
if self.protocol.close_expected():
910909
wait_for_close = True
911-
# If the connection is expected to close soon, set the
912-
# close deadline based on the close timeout.
913-
# Since we tested earlier that protocol.state was OPEN
910+
# Set the close deadline based on the close timeout.
911+
# Since we tested earlier that protocol.state is OPEN
914912
# (or CONNECTING), self.close_deadline is still None.
913+
assert self.close_deadline is None
915914
if self.close_timeout is not None:
916-
assert self.close_deadline is None
917915
self.close_deadline = self.loop.time() + self.close_timeout
918-
# Write outgoing data to the socket and enforce flow control.
916+
# Write outgoing data to the socket with flow control.
919917
try:
920918
self.send_data()
921919
await self.drain()
@@ -933,9 +931,8 @@ async def send_context(
933931
# will be closing soon if it isn't in the expected state.
934932
wait_for_close = True
935933
# Calculate close_deadline if it wasn't set yet.
936-
if self.close_timeout is not None:
937-
if self.close_deadline is None:
938-
self.close_deadline = self.loop.time() + self.close_timeout
934+
if self.close_deadline is None and self.close_timeout is not None:
935+
self.close_deadline = self.loop.time() + self.close_timeout
939936
raise_close_exc = True
940937

941938
# If the connection is expected to close soon and the close timeout
@@ -966,9 +963,6 @@ def send_data(self) -> None:
966963
"""
967964
Send outgoing data.
968965
969-
Raises:
970-
OSError: When a socket operations fails.
971-
972966
"""
973967
for data in self.protocol.data_to_send():
974968
if data:
@@ -982,7 +976,7 @@ def send_data(self) -> None:
982976
# OSError is plausible. uvloop can raise RuntimeError here.
983977
try:
984978
self.transport.write_eof()
985-
except (OSError, RuntimeError): # pragma: no cover
979+
except Exception: # pragma: no cover
986980
pass
987981
# Else, close the TCP connection.
988982
else: # pragma: no cover
@@ -1005,7 +999,8 @@ def set_recv_exc(self, exc: BaseException | None) -> None:
1005999
def connection_made(self, transport: asyncio.BaseTransport) -> None:
10061000
transport = cast(asyncio.Transport, transport)
10071001
self.recv_messages = Assembler(
1008-
*self.max_queue,
1002+
self.max_queue_high,
1003+
self.max_queue_low,
10091004
pause=transport.pause_reading,
10101005
resume=transport.resume_reading,
10111006
)
@@ -1022,7 +1017,7 @@ def connection_lost(self, exc: Exception | None) -> None:
10221017

10231018
# Abort recv() and pending pings with a ConnectionClosed exception.
10241019
self.recv_messages.close()
1025-
self.abort_pings()
1020+
self.terminate_pending_pings()
10261021

10271022
if self.keepalive_task is not None:
10281023
self.keepalive_task.cancel()
@@ -1092,12 +1087,10 @@ def data_received(self, data: bytes) -> None:
10921087
self.logger.debug("! error while sending data", exc_info=True)
10931088
self.set_recv_exc(exc)
10941089

1090+
# If needed, set the close deadline based on the close timeout.
10951091
if self.protocol.close_expected():
1096-
# If the connection is expected to close soon, set the
1097-
# close deadline based on the close timeout.
1098-
if self.close_timeout is not None:
1099-
if self.close_deadline is None:
1100-
self.close_deadline = self.loop.time() + self.close_timeout
1092+
if self.close_deadline is None and self.close_timeout is not None:
1093+
self.close_deadline = self.loop.time() + self.close_timeout
11011094

11021095
for event in events:
11031096
# This isn't expected to raise an exception.
@@ -1205,7 +1198,7 @@ def broadcast(
12051198
if connection.protocol.state is not OPEN:
12061199
continue
12071200

1208-
if connection.fragmented_send_waiter is not None:
1201+
if connection.send_in_progress is not None:
12091202
if raise_exceptions:
12101203
exception = ConcurrencyError("sending a fragmented message")
12111204
exceptions.append(exception)

0 commit comments

Comments
 (0)