Skip to content

Commit c215bf9

Browse files
Andrei Neagu authored
🐛♻️ Better context cleanup (ITISFoundation#2586)
* refactor & add cleanup * attached events to stop * docs and pylint * adding new test and refactored exiting * remove unused * renamed function to avoid confusion Co-authored-by: Andrei Neagu <[email protected]>
1 parent aff0b8e commit c215bf9

File tree

5 files changed

+140
-158
lines changed

5 files changed

+140
-158
lines changed

packages/service-library/src/servicelib/async_utils.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
11
import asyncio
2+
import logging
23
from collections import deque
34
from functools import wraps
4-
from typing import Dict, List
5+
from typing import Dict, List, Optional
56

67
import attr
78

9+
logger = logging.getLogger(__name__)
10+
811

912
@attr.s(auto_attribs=True)
1013
class Context:
1114
in_queue: asyncio.Queue
1215
out_queue: asyncio.Queue
1316
initialized: bool
17+
task: Optional[asyncio.Task] = None
18+
1419

20+
_sequential_jobs_contexts: Dict[str, Context] = {}
1521

16-
sequential_jobs_contexts = {}
22+
23+
async def stop_sequential_workers() -> None:
24+
"""Singlas all workers to close thus avoiding errors on shutdown"""
25+
for context in _sequential_jobs_contexts.values():
26+
await context.in_queue.put(None)
27+
if context.task is not None:
28+
await context.task
29+
_sequential_jobs_contexts.clear()
30+
logger.info("All run_sequentially_in_context pending workers stopped")
1731

1832

1933
def run_sequentially_in_context(target_args: List[str] = None):
@@ -35,11 +49,13 @@ async def func(param1, param2, param3):
3549
3650
functions = [
3751
func(1, "something", 3),
38-
func(1, "else", 3),
52+
func(1, "argument.attribute", 3),
3953
func(1, "here", 3),
4054
]
4155
await asyncio.gather(*functions)
4256
57+
note the special "argument.attribute", which will use the attribute of argument to create the context.
58+
4359
The following calls will run in parallel, because they have different contexts:
4460
4561
functions = [
@@ -62,24 +78,34 @@ def get_context(args, kwargs: Dict) -> Context:
6278

6379
key_parts = deque()
6480
for arg in target_args:
65-
if arg not in search_args:
81+
sub_args = arg.split(".")
82+
main_arg = sub_args[0]
83+
if main_arg not in search_args:
6684
message = (
67-
f"Expected '{arg}' in '{decorated_function.__name__}'"
85+
f"Expected '{main_arg}' in '{decorated_function.__name__}'"
6886
f" arguments. Got '{search_args}'"
6987
)
7088
raise ValueError(message)
71-
key_parts.append(search_args[arg])
89+
context_key = search_args[main_arg]
90+
for attribute in sub_args[1:]:
91+
potential_key = getattr(context_key, attribute)
92+
if not potential_key:
93+
message = f"Expected '{attribute}' attribute in '{context_key.__name__}' arguments."
94+
raise ValueError(message)
95+
context_key = potential_key
96+
97+
key_parts.append(f"{decorated_function.__name__}_{context_key}")
7298

7399
key = ":".join(map(str, key_parts))
74100

75-
if key not in sequential_jobs_contexts:
76-
sequential_jobs_contexts[key] = Context(
101+
if key not in _sequential_jobs_contexts:
102+
_sequential_jobs_contexts[key] = Context(
77103
in_queue=asyncio.Queue(),
78104
out_queue=asyncio.Queue(),
79105
initialized=False,
80106
)
81107

82-
return sequential_jobs_contexts[key]
108+
return _sequential_jobs_contexts[key]
83109

84110
@wraps(decorated_function)
85111
async def wrapper(*args, **kwargs):
@@ -92,13 +118,22 @@ async def worker(in_q: asyncio.Queue, out_q: asyncio.Queue):
92118
while True:
93119
awaitable = await in_q.get()
94120
in_q.task_done()
121+
# check if requested to shutdown
122+
if awaitable is None:
123+
break
95124
try:
96125
result = await awaitable
97126
except Exception as e: # pylint: disable=broad-except
98127
result = e
99128
await out_q.put(result)
100129

101-
asyncio.get_event_loop().create_task(
130+
logging.info(
131+
"Closed worker for @run_sequentially_in_context applied to '%s' with target_args=%s",
132+
decorated_function.__name__,
133+
target_args,
134+
)
135+
136+
context.task = asyncio.create_task(
102137
worker(context.in_queue, context.out_queue)
103138
)
104139

packages/service-library/tests/test_async_utils.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,47 @@
11
# pylint: disable=redefined-outer-name
2+
# pylint: disable=unused-argument
23

34
import asyncio
45
import copy
56
import random
67
from collections import deque
8+
from dataclasses import dataclass
79
from time import time
8-
from typing import Any, Dict, List
10+
from typing import Any, AsyncIterable, Dict, List, Optional
911

1012
import pytest
11-
from servicelib.async_utils import run_sequentially_in_context, sequential_jobs_contexts
13+
from servicelib.async_utils import (
14+
_sequential_jobs_contexts,
15+
run_sequentially_in_context,
16+
stop_sequential_workers,
17+
)
1218

19+
RETRIES = 10
20+
DIFFERENT_CONTEXTS_COUNT = 10
1321

14-
@pytest.fixture(autouse=True)
15-
def ensure_run_in_sequence_context_is_empty():
16-
# NOTE: since the contexts variable is initialized at import time, when several test run
17-
# the import happens only once and is rendered invalid, therefore explicit clearance is necessary
18-
sequential_jobs_contexts.clear()
22+
23+
@pytest.fixture
24+
async def ensure_run_in_sequence_context_is_empty(loop) -> AsyncIterable[None]:
25+
yield
26+
# NOTE
27+
# required when shutting down the application or ending tests
28+
# otherwise errors will occur when closing the loop
29+
await stop_sequential_workers()
30+
31+
32+
@pytest.fixture
33+
def payload() -> str:
34+
return "some string payload"
35+
36+
37+
@pytest.fixture
38+
def expected_param_name() -> str:
39+
return "expected_param_name"
40+
41+
42+
@pytest.fixture
43+
def sleep_duration() -> float:
44+
return 0.01
1945

2046

2147
class LockedStore:
@@ -34,12 +60,14 @@ async def get_all(self) -> List[Any]:
3460
return list(self._queue)
3561

3662

37-
async def test_context_aware_dispatch() -> None:
63+
async def test_context_aware_dispatch(
64+
sleep_duration: float,
65+
ensure_run_in_sequence_context_is_empty: None,
66+
) -> None:
3867
@run_sequentially_in_context(target_args=["c1", "c2", "c3"])
3968
async def orderly(c1: Any, c2: Any, c3: Any, control: Any) -> None:
4069
_ = (c1, c2, c3)
41-
sleep_interval = random.uniform(0, 0.01)
42-
await asyncio.sleep(sleep_interval)
70+
await asyncio.sleep(sleep_duration)
4371

4472
context = dict(c1=c1, c2=c2, c3=c3)
4573
await locked_stores[make_key_from_context(context)].push(control)
@@ -81,12 +109,14 @@ def make_context():
81109
assert list(expected_outcomes[key]) == await locked_stores[key].get_all()
82110

83111

84-
async def test_context_aware_function_sometimes_fails() -> None:
112+
async def test_context_aware_function_sometimes_fails(
113+
ensure_run_in_sequence_context_is_empty: None,
114+
) -> None:
85115
class DidFailException(Exception):
86116
pass
87117

88118
@run_sequentially_in_context(target_args=["will_fail"])
89-
async def sometimes_failing(will_fail: bool) -> None:
119+
async def sometimes_failing(will_fail: bool) -> bool:
90120
if will_fail:
91121
raise DidFailException("I was instructed to fail")
92122
return True
@@ -101,8 +131,10 @@ async def sometimes_failing(will_fail: bool) -> None:
101131
assert await sometimes_failing(raise_error) is True
102132

103133

104-
async def test_context_aware_wrong_target_args_name() -> None:
105-
expected_param_name = "wrong_parameter"
134+
async def test_context_aware_wrong_target_args_name(
135+
expected_param_name: str,
136+
ensure_run_in_sequence_context_is_empty: None, # pylint: disable=unused-argument
137+
) -> None:
106138

107139
# pylint: disable=unused-argument
108140
@run_sequentially_in_context(target_args=[expected_param_name])
@@ -119,15 +151,17 @@ async def target_function(the_param: Any) -> None:
119151
assert str(excinfo.value).startswith(message) is True
120152

121153

122-
async def test_context_aware_measure_parallelism() -> None:
154+
async def test_context_aware_measure_parallelism(
155+
sleep_duration: float,
156+
ensure_run_in_sequence_context_is_empty: None,
157+
) -> None:
123158
# expected duration 1 second
124159
@run_sequentially_in_context(target_args=["control"])
125160
async def sleep_for(sleep_interval: float, control: Any) -> Any:
126161
await asyncio.sleep(sleep_interval)
127162
return control
128163

129-
control_sequence = list(range(1000))
130-
sleep_duration = 0.5
164+
control_sequence = list(range(RETRIES))
131165
functions = [sleep_for(sleep_duration, x) for x in control_sequence]
132166

133167
start = time()
@@ -138,15 +172,17 @@ async def sleep_for(sleep_interval: float, control: Any) -> Any:
138172
assert control_sequence == result
139173

140174

141-
async def test_context_aware_measure_serialization() -> None:
175+
async def test_context_aware_measure_serialization(
176+
sleep_duration: float,
177+
ensure_run_in_sequence_context_is_empty: None,
178+
) -> None:
142179
# expected duration 1 second
143180
@run_sequentially_in_context(target_args=["control"])
144181
async def sleep_for(sleep_interval: float, control: Any) -> Any:
145182
await asyncio.sleep(sleep_interval)
146183
return control
147184

148-
control_sequence = [1 for _ in range(10)]
149-
sleep_duration = 0.1
185+
control_sequence = [1 for _ in range(RETRIES)]
150186
functions = [sleep_for(sleep_duration, x) for x in control_sequence]
151187

152188
start = time()
@@ -156,3 +192,36 @@ async def sleep_for(sleep_interval: float, control: Any) -> Any:
156192
minimum_timelapse = (sleep_duration) * len(control_sequence)
157193
assert elapsed > minimum_timelapse
158194
assert control_sequence == result
195+
196+
197+
async def test_nested_object_attribute(
198+
payload: str,
199+
ensure_run_in_sequence_context_is_empty: None,
200+
) -> None:
201+
@dataclass
202+
class ObjectWithPropos:
203+
attr1: str = payload
204+
205+
@run_sequentially_in_context(target_args=["object_with_props.attr1"])
206+
async def test_attribute(
207+
object_with_props: ObjectWithPropos, other_attr: Optional[int] = None
208+
) -> str:
209+
return object_with_props.attr1
210+
211+
for _ in range(RETRIES):
212+
assert payload == await test_attribute(ObjectWithPropos())
213+
214+
215+
async def test_different_contexts(
216+
payload: str,
217+
ensure_run_in_sequence_context_is_empty: None,
218+
) -> None:
219+
@run_sequentially_in_context(target_args=["context_param"])
220+
async def test_multiple_context_calls(context_param: int) -> int:
221+
return context_param
222+
223+
for _ in range(RETRIES):
224+
for i in range(DIFFERENT_CONTEXTS_COUNT):
225+
assert i == await test_multiple_context_calls(i)
226+
227+
assert len(_sequential_jobs_contexts) == RETRIES

services/director-v2/src/simcore_service_director_v2/api/routes/computations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi import APIRouter, Depends, HTTPException
88
from models_library.projects import ProjectAtDB, ProjectID
99
from models_library.projects_state import RunningState
10+
from servicelib.async_utils import run_sequentially_in_context
1011
from starlette import status
1112
from starlette.requests import Request
1213
from tenacity import (
@@ -32,7 +33,6 @@
3233
from ...modules.db.repositories.comp_tasks import CompTasksRepository
3334
from ...modules.db.repositories.projects import ProjectsRepository
3435
from ...modules.director_v0 import DirectorV0Client
35-
from ...utils.async_utils import run_sequentially_in_context
3636
from ...utils.computations import (
3737
get_pipeline_state_from_task_states,
3838
is_pipeline_running,

services/director-v2/src/simcore_service_director_v2/core/events.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from servicelib.async_utils import stop_sequential_workers
2+
13
from ..meta import PROJECT_NAME, __version__
24

35
#
@@ -16,10 +18,11 @@
1618
)
1719

1820

19-
def on_startup() -> None:
21+
async def on_startup() -> None:
2022
print(WELCOME_MSG, flush=True)
2123

2224

23-
def on_shutdown() -> None:
25+
async def on_shutdown() -> None:
26+
await stop_sequential_workers()
2427
msg = PROJECT_NAME + f" v{__version__} SHUT DOWN"
2528
print(f"{msg:=^100}", flush=True)

0 commit comments

Comments (0)