Skip to content

Commit 5fa147d

Browse files
Don't nest contexts in modules (#137)
1 parent 0dd8f86 commit 5fa147d

File tree

4 files changed

+121
-59
lines changed

4 files changed

+121
-59
lines changed

src/fps/_context.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def __init__(
101101
self._exit_stack: ExitStack | None = None
102102
self._async_exit_stack: AsyncExitStack | None = None
103103
self._opened = False
104+
self._closing = False
104105

105106
def _drop(self, borrower: Value) -> None:
106107
if borrower in self._borrowers:
@@ -191,6 +192,11 @@ async def aclose(
191192
Raises:
192193
TimeoutError: If the shared value could not be closed in time.
193194
"""
195+
if self._closing:
196+
return
197+
198+
self._closing = True
199+
194200
if timeout is None:
195201
timeout = self._close_timeout
196202
if timeout is None:
@@ -228,7 +234,7 @@ class Context:
228234
def __init__(self) -> None:
229235
self._context: dict[int, SharedValue] = {}
230236
self._value_added = Event()
231-
self._closed = Event()
237+
self._closed = False
232238
self._teardown_callbacks: list[
233239
Callable[..., Any] | Callable[..., Awaitable[Any]]
234240
] = []
@@ -246,21 +252,12 @@ async def __aenter__(self) -> Context:
246252
return self
247253

248254
async def __aexit__(self, exc_type, exc_value, exc_tb):
249-
for context in set(self._children):
250-
await context._closed.wait()
251-
async with create_task_group() as tg:
252-
for shared_value in self._context.values():
253-
tg.start_soon(
254-
partial(
255-
shared_value.aclose,
256-
_exc_type=exc_type,
257-
_exc_value=exc_value,
258-
_exc_tb=exc_tb,
259-
)
260-
)
261-
for callback in self._teardown_callbacks[::-1]:
262-
await call(callback, exc_value)
263-
self._closed.set()
255+
await self.aclose(
256+
timeout=None,
257+
_exc_type=exc_type,
258+
_exc_value=exc_value,
259+
_exc_tb=exc_tb,
260+
)
264261
_current_context.reset(self._token)
265262
if self._parent is not None:
266263
self._parent._children.remove(self)
@@ -277,7 +274,7 @@ def _get_value_types(
277274
return types
278275

279276
def _check_closed(self):
280-
if self._closed.is_set():
277+
if self._closed:
281278
raise RuntimeError("Context is closed")
282279

283280
def add_teardown_callback(
@@ -302,6 +299,7 @@ def put(
302299
teardown_callback: Callable[..., Any]
303300
| Callable[..., Awaitable[Any]]
304301
| None = None,
302+
shared_value: SharedValue[T] | None = None,
305303
) -> SharedValue[T]:
306304
"""
307305
Put a value in the context so that it can be shared.
@@ -319,12 +317,16 @@ def put(
319317
The shared value.
320318
"""
321319
self._check_closed()
322-
_shared_value = SharedValue(
323-
value,
324-
max_borrowers=max_borrowers,
325-
manage=manage,
326-
teardown_callback=teardown_callback,
327-
)
320+
if shared_value is not None:
321+
_shared_value = shared_value
322+
value = _shared_value._value
323+
else:
324+
_shared_value = SharedValue(
325+
value,
326+
max_borrowers=max_borrowers,
327+
manage=manage,
328+
teardown_callback=teardown_callback,
329+
)
328330
_types = self._get_value_types(value, types)
329331
for value_type in _types:
330332
value_type_id = id(value_type)
@@ -379,6 +381,41 @@ async def _get(self, value_type: type[T]) -> Value[T]:
379381
return await shared_value.get()
380382
await self._value_added.wait()
381383

384+
async def aclose(
385+
self,
386+
*,
387+
timeout: float | None = None,
388+
_exc_type=None,
389+
_exc_value: BaseException | None = None,
390+
_exc_tb=None,
391+
) -> None:
392+
"""
393+
Close the context, after all shared values that were borrowed have been dropped.
394+
The shared values will be torn down, if applicable.
395+
396+
Args:
397+
timeout: The time to wait for all shared values to be freed.
398+
399+
Raises:
400+
TimeoutError: If the context could not be closed in time.
401+
"""
402+
if timeout is None:
403+
timeout = float("inf")
404+
with fail_after(timeout):
405+
async with create_task_group() as tg:
406+
for shared_value in self._context.values():
407+
tg.start_soon(
408+
partial(
409+
shared_value.aclose,
410+
_exc_type=_exc_type,
411+
_exc_value=_exc_value,
412+
_exc_tb=_exc_tb,
413+
)
414+
)
415+
for callback in self._teardown_callbacks[::-1]:
416+
await call(callback, _exc_value)
417+
self._closed = True
418+
382419

383420
@lru_cache(maxsize=1024)
384421
def count_parameters(func: Callable) -> int:

src/fps/_module.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from collections.abc import Callable, Awaitable
77
from contextlib import AsyncExitStack
88
from inspect import isawaitable, signature, _empty
9-
from typing import TypeVar, Any, Iterable
9+
from typing import TypeVar, Any, Iterable, cast
1010

1111
import anyio
1212
import structlog
13-
from anyio import Event, create_task_group, move_on_after
13+
from anyio import Event, create_task_group, fail_after, move_on_after
1414
from anyioutils import create_task, wait, FIRST_COMPLETED
1515

1616
from ._context import Context, SharedValue, Value
@@ -69,6 +69,7 @@ def __init__(
6969
self._start_timeout = start_timeout
7070
self._stop_timeout = stop_timeout
7171
self._parent: Module | None = None
72+
self._context = Context()
7273
self._prepared = Event()
7374
self._started = Event()
7475
self._stopped = Event()
@@ -256,24 +257,24 @@ def put(
256257
teardown_callback: A callback to call when the value is torn down.
257258
manage: Whether to use the (async) context manager of the value for its setup/teardown.
258259
"""
259-
if self.parent is None:
260-
shared_value = self._context.put(
261-
value,
262-
types,
263-
max_borrowers=max_borrowers,
264-
manage=manage,
265-
teardown_callback=teardown_callback,
266-
)
267-
else:
268-
shared_value = self.parent._context.put(
260+
value_id = id(value)
261+
shared_value = self._context.put(
262+
value,
263+
types,
264+
max_borrowers=max_borrowers,
265+
manage=manage,
266+
teardown_callback=teardown_callback,
267+
)
268+
self._published_values[value_id] = shared_value
269+
if self.parent is not None:
270+
self.parent._context.put(
269271
value,
270272
types,
271273
max_borrowers=max_borrowers,
272274
manage=manage,
273275
teardown_callback=teardown_callback,
276+
shared_value=shared_value,
274277
)
275-
value_id = id(value)
276-
self._published_values[value_id] = shared_value
277278
log.debug("Module added value", path=self.path, types=types)
278279

279280
async def get(
@@ -290,10 +291,24 @@ async def get(
290291
The borrowed value.
291292
"""
292293
log.debug("Module getting value", path=self.path, value_type=value_type)
294+
tasks = [create_task(self._context.get(value_type), self._task_group)]
295+
if self.parent is not None:
296+
tasks.append(
297+
create_task(self.parent._context.get(value_type), self._task_group)
298+
)
293299
value_acquired = False
294300
try:
295-
value = await self._context.get(value_type, timeout=timeout)
296-
value_acquired = True
301+
with fail_after(timeout):
302+
done, pending = await wait(
303+
tasks, self._task_group, return_when=FIRST_COMPLETED
304+
)
305+
for task in pending:
306+
task.cancel()
307+
for task in done:
308+
break
309+
value = await task.wait()
310+
value = cast(Value, value)
311+
value_acquired = True
297312
finally:
298313
if not value_acquired:
299314
log.critical(
@@ -441,27 +456,26 @@ async def _finish(self):
441456

442457
async def _drop_and_wait_values(self):
443458
self.drop_all()
459+
await self._context.aclose()
444460
self._stopped.set()
445461
log.debug("Module stopped", path=self.path)
446462

447463
async def _prepare(self) -> None:
448464
log.debug("Preparing module", path=self.path)
449-
async with Context() as self._context:
450-
try:
451-
async with create_task_group() as tg:
452-
for module in self._modules.values():
453-
module._task_group = tg
454-
module._phase = self._phase
455-
module._exceptions = self._exceptions
456-
tg.start_soon(module._prepare, name=f"{module.path} _prepare")
457-
tg.start_soon(
458-
self._prepare_and_done, name=f"{self.path} _prepare_and_done"
459-
)
460-
except ExceptionGroup as exc:
461-
self._exceptions.append(*exc.exceptions)
462-
self._exit.set()
463-
log.critical("Module failed while preparing", path=self.path)
464-
await self._stopped.wait()
465+
try:
466+
async with create_task_group() as tg:
467+
for module in self._modules.values():
468+
module._task_group = tg
469+
module._phase = self._phase
470+
module._exceptions = self._exceptions
471+
tg.start_soon(module._prepare, name=f"{module.path} _prepare")
472+
tg.start_soon(
473+
self._prepare_and_done, name=f"{self.path} _prepare_and_done"
474+
)
475+
except ExceptionGroup as exc:
476+
self._exceptions.append(*exc.exceptions)
477+
self._exit.set()
478+
log.critical("Module failed while preparing", path=self.path)
465479

466480
async def _prepare_and_done(self) -> None:
467481
await self.prepare()

tests/test_context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ async def test_nested_contexts():
5454
acquired_value_0 = await get(str)
5555
assert published_value_0 is acquired_value_0.unwrap()
5656
acquired_value_0.drop()
57+
async with Context():
58+
acquired_value_1 = await get(str)
59+
assert published_value_0 is acquired_value_1.unwrap()
60+
acquired_value_1.drop()
5761

5862

5963
async def test_context_cm():

tests/test_value.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ class Module2(Module):
9292
async def start(self):
9393
self.value2 = Value2()
9494
self.put(self.value2)
95-
self.value0 = await self.get(Value0, timeout=0.1)
95+
try:
96+
self.value0 = await self.get(Value0, timeout=0.1)
97+
except TimeoutError:
98+
self.value0 = None
9699
self.value1 = await self.get(Value1, timeout=0.1)
97100

98101
async with Module0("module0") as module0:
@@ -104,7 +107,7 @@ async def start(self):
104107
assert module0.value2 is None
105108
assert module1.value0 == module0.value0
106109
assert module1.value2 == module2.value2
107-
assert module2.value0 == module0.value0
110+
assert module2.value0 is None
108111
assert module2.value1 == module1.value1
109112

110113

@@ -251,8 +254,12 @@ async def stop(self):
251254
async with Module0("module0", stop_timeout=0.1) as module0:
252255
pass
253256

254-
assert len(module0.exceptions) == 1
255-
assert str(module0.exceptions[0]) == "Module timed out while stopping: module0"
257+
assert len(module0.exceptions) == 2
258+
assert (
259+
str(module0.exceptions[0])
260+
== "Module timed out while stopping: module0.submodule0"
261+
)
262+
assert str(module0.exceptions[1]) == "Module timed out while stopping: module0"
256263

257264

258265
async def test_all_freed():

0 commit comments

Comments
 (0)