Skip to content

Commit 0fec860

Browse files
feat: add asyncio server state machine
1 parent 92e5f82 commit 0fec860

File tree

4 files changed

+86
-13
lines changed

4 files changed

+86
-13
lines changed

Lib/asyncio/base_events.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import collections
1717
import collections.abc
1818
import concurrent.futures
19+
import enum
1920
import errno
2021
import heapq
2122
import itertools
@@ -272,6 +273,23 @@ async def restore(self):
272273
self._proto.resume_writing()
273274

274275

276+
class _ServerState(enum.Enum):
277+
"""This tracks the state of Server.
278+
279+
-[in]->INITIALIZED -[ss]-> SERVING -[cl]-> CLOSED -[wk]*-> SHUTDOWN
280+
281+
- in: Server.__init__()
282+
- ss: Server._start_serving()
283+
- cl: Server.close()
284+
- wk: Server._wakeup() *only called if number of clients == 0
285+
"""
286+
287+
INITIALIZED = "initialized"
288+
SERVING = "serving"
289+
CLOSED = "closed"
290+
SHUTDOWN = "shutdown"
291+
292+
275293
class Server(events.AbstractServer):
276294

277295
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
@@ -287,32 +305,49 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
287305
self._ssl_context = ssl_context
288306
self._ssl_handshake_timeout = ssl_handshake_timeout
289307
self._ssl_shutdown_timeout = ssl_shutdown_timeout
290-
self._serving = False
308+
self._state = _ServerState.INITIALIZED
291309
self._serving_forever_fut = None
292310

293311
def __repr__(self):
294312
return f'<{self.__class__.__name__} sockets={self.sockets!r}>'
295313

296314
def _attach(self, transport):
297-
assert self._sockets is not None
315+
if self._state != _ServerState.SERVING:
316+
raise RuntimeError("server is not serving, cannot attach transport")
298317
self._clients.add(transport)
299318

300319
def _detach(self, transport):
301320
self._clients.discard(transport)
302-
if len(self._clients) == 0 and self._sockets is None:
321+
if self._state == _ServerState.CLOSED and len(self._clients) == 0:
303322
self._wakeup()
304323

305324
def _wakeup(self):
325+
match self._state:
326+
case _ServerState.SHUTDOWN:
327+
# gh109564: the wakeup method has two possible call-sites,
328+
# through an explicit call Server.close(), or indirectly through
329+
# Server._detach() by the last connected client.
330+
return
331+
case _ServerState.INITIALIZED | _ServerState.SERVING:
332+
raise RuntimeError("cannot wakeup server before closing")
333+
case _ServerState.CLOSED:
334+
self._state = _ServerState.SHUTDOWN
335+
306336
waiters = self._waiters
307337
self._waiters = None
308338
for waiter in waiters:
309339
if not waiter.done():
310340
waiter.set_result(None)
311341

312342
def _start_serving(self):
313-
if self._serving:
314-
return
315-
self._serving = True
343+
match self._state:
344+
case _ServerState.SERVING:
345+
return
346+
case _ServerState.CLOSED | _ServerState.SHUTDOWN:
347+
raise RuntimeError(f'server {self!r} is closed')
348+
case _ServerState.INITIALIZED:
349+
self._state = _ServerState.SERVING
350+
316351
for sock in self._sockets:
317352
sock.listen(self._backlog)
318353
self._loop._start_serving(
@@ -324,7 +359,7 @@ def get_loop(self):
324359
return self._loop
325360

326361
def is_serving(self):
327-
return self._serving
362+
return self._state == _ServerState.SERVING
328363

329364
@property
330365
def sockets(self):
@@ -333,6 +368,13 @@ def sockets(self):
333368
return tuple(trsock.TransportSocket(s) for s in self._sockets)
334369

335370
def close(self):
371+
match self._state:
372+
case _ServerState.CLOSED | _ServerState.SHUTDOWN:
373+
# Shutdown state can only be reached after closing.
374+
return
375+
case _:
376+
self._state = _ServerState.CLOSED
377+
336378
sockets = self._sockets
337379
if sockets is None:
338380
return
@@ -341,8 +383,6 @@ def close(self):
341383
for sock in sockets:
342384
self._loop._stop_serving(sock)
343385

344-
self._serving = False
345-
346386
if (self._serving_forever_fut is not None and
347387
not self._serving_forever_fut.done()):
348388
self._serving_forever_fut.cancel()
@@ -369,8 +409,6 @@ async def serve_forever(self):
369409
if self._serving_forever_fut is not None:
370410
raise RuntimeError(
371411
f'server {self!r} is already being awaited on serve_forever()')
372-
if self._sockets is None:
373-
raise RuntimeError(f'server {self!r} is closed')
374412

375413
self._start_serving()
376414
self._serving_forever_fut = self._loop.create_future()
@@ -407,7 +445,7 @@ async def wait_closed(self):
407445
# from two places: self.close() and self._detach(), but only
408446
# when both conditions have become true. To signal that this
409447
# has happened, self._wakeup() sets self._waiters to None.
410-
if self._waiters is None:
448+
if self._state == _ServerState.SHUTDOWN:
411449
return
412450
waiter = self._loop.create_future()
413451
self._waiters.append(waiter)

Lib/asyncio/selector_events.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,12 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
795795
self._paused = False # Set when pause_reading() called
796796

797797
if self._server is not None:
798-
self._server._attach(self)
798+
if self._server.is_serving():
799+
self._server._attach(self)
800+
else:
801+
self.abort()
802+
return
803+
799804
loop._transports[self._sock_fd] = self
800805

801806
def __repr__(self):

Lib/test/test_asyncio/test_server.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
import threading
66
import unittest
7+
from unittest.mock import Mock
78

89
from test.support import socket_helper
910
from test.test_asyncio import utils as test_utils
@@ -186,6 +187,8 @@ async def serve(rd, wr):
186187
loop.call_soon(srv.close)
187188
loop.call_soon(wr.close)
188189
await srv.wait_closed()
190+
self.assertTrue(task.done())
191+
self.assertFalse(srv.is_serving())
189192

190193
async def test_close_clients(self):
191194
async def serve(rd, wr):
@@ -212,6 +215,9 @@ async def serve(rd, wr):
212215
await asyncio.sleep(0)
213216
self.assertTrue(task.done())
214217

218+
with self.assertRaisesRegex(RuntimeError, r'is closed'):
219+
await srv.start_serving()
220+
215221
async def test_abort_clients(self):
216222
async def serve(rd, wr):
217223
fut.set_result((rd, wr))
@@ -266,6 +272,29 @@ async def serve(rd, wr):
266272
await asyncio.sleep(0)
267273
self.assertTrue(task.done())
268274

275+
async def test_close_before_transport_attach(self):
276+
proto = Mock()
277+
loop = asyncio.get_running_loop()
278+
srv = await loop.create_server(lambda *_: proto, socket_helper.HOSTv4, 0)
279+
280+
await srv.start_serving()
281+
addr = srv.sockets[0].getsockname()
282+
283+
# Create a connection to the server but close the server before the
284+
# socket transport for the connection is created and attached
285+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
286+
s.connect(addr)
287+
await asyncio.sleep(0) # loop select reader
288+
await asyncio.sleep(0) # accept conn 1
289+
srv.close()
290+
291+
# Ensure the protocol is given an opportunity to handle this event
292+
# gh109564: the transport would be unclosed and will cause a loop
293+
# exception due to a double-call to Server._wakeup
294+
await asyncio.sleep(0)
295+
await asyncio.sleep(0)
296+
proto.connection_lost.assert_called()
297+
269298

270299
# Test the various corner cases of Unix server socket removal
271300
class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix race condition in :meth:`asyncio.Server.close`. Patch by Jamie Phan.

0 commit comments

Comments
 (0)