Skip to content

Commit 57c78cd

Browse files
authored
Add names to all asyncio tasks (#152)
* Attach names to all tasks * Name task in wait_for * Cover the kick event too * Fix copy paste error * Only run on python 312 for task naming check
1 parent 40e5233 commit 57c78cd

File tree

10 files changed

+81
-6
lines changed

10 files changed

+81
-6
lines changed

dispatcherd/brokers/socket.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
77

88
from ..protocols import Broker as BrokerProtocol
9+
from ..service.asyncio_tasks import named_wait
910

1011
logger = logging.getLogger(__name__)
1112

@@ -77,6 +78,9 @@ async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.Stream
7778
client = Client(self.client_ct, reader, writer)
7879
self.clients[self.client_ct] = client
7980
self.client_ct += 1
81+
current_task = asyncio.current_task()
82+
if current_task is not None:
83+
current_task.set_name(f'socket_client_task_{client.client_id}')
8084
logger.info(f'Socket client_id={client.client_id} is connected')
8185

8286
try:
@@ -89,7 +93,7 @@ async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.Stream
8993
await self.incoming_queue.put((client.client_id, message))
9094
# Wait for caller to potentially fill a reply queue
9195
# this should realistically never take more than a trivial amount of time
92-
await asyncio.wait_for(client.yield_clear.wait(), timeout=2)
96+
await asyncio.wait_for(named_wait(client.yield_clear, f'internal_wait_for_client_{client.client_id}'), timeout=2)
9397
client.yield_clear.clear()
9498
await client.send_replies()
9599
except asyncio.TimeoutError:

dispatcherd/producers/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def __init__(self) -> None:
1010

1111

1212
class BaseProducer(ProducerProtocol):
13+
can_recycle: bool = False
1314

1415
def __init__(self) -> None:
1516
self.events = ProducerEvents()

dispatcherd/producers/brokered.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010

1111
class BrokeredProducer(BaseProducer):
12+
can_recycle = True
13+
1214
def __init__(self, broker: Broker) -> None:
1315
self.production_task: Optional[asyncio.Task] = None
1416
self.broker = broker

dispatcherd/protocols.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class Producer(Protocol):
7575
"""
7676

7777
events: ProducerEvents
78+
can_recycle: bool
7879

7980
async def start_producing(self, dispatcher: 'DispatcherMain') -> None:
8081
"""Starts tasks which will eventually call DispatcherMain.process_message - how tasks originate in the service"""

dispatcherd/service/asyncio_tasks.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ def ensure_fatal(task: asyncio.Task, exit_event: Optional[asyncio.Event] = None)
3636
return task # nicety so this can be used as a wrapper
3737

3838

39-
async def wait_for_any(events: Iterable[asyncio.Event]) -> int:
39+
async def wait_for_any(events: Iterable[asyncio.Event], names: Optional[Iterable[str]] = None) -> int:
4040
"""
4141
Wait for a list of events. If any of the events gets set, this function
4242
will return
4343
"""
44-
tasks = [asyncio.create_task(event.wait()) for event in events]
44+
if names:
45+
tasks = [asyncio.create_task(event.wait(), name=task_name) for (event, task_name) in zip(events, names)]
46+
else:
47+
tasks = [asyncio.create_task(event.wait()) for event in events]
4548
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
4649
for task in pending:
4750
task.cancel()
@@ -53,3 +56,12 @@ async def wait_for_any(events: Iterable[asyncio.Event]) -> int:
5356
return i
5457

5558
raise RuntimeError('Internal error - could done find any tasks that are done')
59+
60+
61+
async def named_wait(event: asyncio.Event, name: str) -> None:
62+
"""Add a name to waiting task so it is visible via debugging commands"""
63+
current_task = asyncio.current_task()
64+
if current_task:
65+
current_task.set_name(name)
66+
67+
await event.wait()

dispatcherd/service/main.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ async def cancel_tasks(self) -> None:
242242
async def recycle_broker_producers(self) -> None:
243243
"""For any producer in a broken state (likely due to external factors beyond our control) recycle it"""
244244
for producer in self.producers:
245+
if not producer.can_recycle:
246+
continue
245247
if producer.events.recycle_event.is_set():
246248
await producer.recycle()
247249
for task in producer.all_tasks():
@@ -251,12 +253,20 @@ async def recycle_broker_producers(self) -> None:
251253
async def main_loop_wait(self) -> None:
252254
"""Wait for an event that requires some kind of action by the main loop"""
253255
events = [self.events.exit_event]
256+
names = ['exit_event_wait']
254257
for producer in self.producers:
258+
if not producer.can_recycle:
259+
continue
255260
events.append(producer.events.recycle_event)
261+
names.append(f'{str(producer)}_recycle_event_wait')
256262

257-
await wait_for_any(events)
263+
await wait_for_any(events, names=names)
258264

259265
async def main(self) -> None:
266+
current_task = asyncio.current_task()
267+
if current_task is not None:
268+
current_task.set_name('dispatcherd_service_main')
269+
260270
await self.connect_signals()
261271

262272
try:

dispatcherd/service/next_wakeup_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import abstractmethod
55
from typing import Any, Callable, Coroutine, Iterable, Optional
66

7-
from .asyncio_tasks import ensure_fatal
7+
from .asyncio_tasks import ensure_fatal, named_wait
88

99
logger = logging.getLogger(__name__)
1010

@@ -92,7 +92,7 @@ async def background_task(self) -> None:
9292
delta = 0.1
9393

9494
try:
95-
await asyncio.wait_for(self.kick_event.wait(), timeout=delta)
95+
await asyncio.wait_for(named_wait(self.kick_event, f'{self.name}_kick_event_wait'), timeout=delta)
9696
except asyncio.TimeoutError:
9797
pass # intended mechanism to hit the next schedule
9898
except asyncio.CancelledError:

tests/integration/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import sys
2+
3+
import pytest
4+
5+
6+
@pytest.fixture
7+
def python312():
8+
if sys.version_info < (3, 12):
9+
pytest.skip("test requires python 3.12")

tests/integration/test_main.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from dispatcherd.config import temporary_settings
9+
from dispatcherd.service.control_tasks import _stack_from_task
910
from tests.data import methods as test_methods
1011

1112
SLEEP_METHOD = 'lambda: __import__("time").sleep(0.1)'
@@ -293,3 +294,18 @@ async def test_scale_up(apg_dispatcher, test_settings):
293294
break
294295
else:
295296
assert f'Never scaled up to expected 6 workers, have: {apg_dispatcher.pool.workers}'
297+
298+
299+
@pytest.mark.asyncio
300+
async def test_tasks_are_named(apg_dispatcher, python312):
301+
wait_task = asyncio.create_task(apg_dispatcher.main_loop_wait(), name='this_is_for_test')
302+
303+
current_task = asyncio.current_task()
304+
for task in asyncio.all_tasks():
305+
if task is current_task:
306+
continue
307+
task_name = task.get_name()
308+
assert not task_name.startswith('Task-'), _stack_from_task(task)
309+
310+
apg_dispatcher.events.exit_event.set()
311+
await wait_task

tests/integration/test_socket_use.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dispatcherd.control import Control
1111
from dispatcherd.factories import from_settings, get_control_from_settings, get_publisher_from_settings
1212
from dispatcherd.protocols import DispatcherMain
13+
from dispatcherd.service.control_tasks import _stack_from_task
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -84,3 +85,22 @@ def alive_cmd():
8485
data = alive[0]
8586

8687
assert data['node_id'] == 'socket-test-server'
88+
89+
90+
@pytest.mark.asyncio
91+
async def test_socket_tasks_are_named(asock_dispatcher, sock_control, python312):
92+
loop = asyncio.get_event_loop()
93+
94+
def aio_tasks_cmd():
95+
return sock_control.control_with_reply('aio_tasks')
96+
97+
aio_tasks = await loop.run_in_executor(None, aio_tasks_cmd)
98+
99+
current_task_name = asyncio.current_task().get_name()
100+
101+
assert len(aio_tasks) == 1
102+
data = aio_tasks[0]
103+
for task_name, task_stuff in data.items():
104+
if task_name == current_task_name:
105+
continue
106+
assert not task_name.startswith('Task-'), task_stuff['stack']

0 commit comments

Comments
 (0)