Skip to content

Commit 132e5ef

Browse files
authored
fix: Prevent async deadlock on circular dependency
1 parent 364f583 commit 132e5ef

File tree

7 files changed

+145
-92
lines changed

7 files changed

+145
-92
lines changed

injection/_core/common/threading.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from threading import RLock
44
from typing import Any, ContextManager, Final
55

6-
_PYTHON_INJECTION_THREADSAFE: Final[bool] = bool(getenv("PYTHON_INJECTION_THREADSAFE"))
6+
_PYTHON_INJECTION_THREADSAFE: Final[bool] = bool(
7+
int(getenv("PYTHON_INJECTION_THREADSAFE", 0))
8+
)
79

810

911
def get_lock(threadsafe: bool | None = None) -> ContextManager[Any]:
10-
cond = _PYTHON_INJECTION_THREADSAFE if threadsafe is None else threadsafe
11-
return RLock() if cond else nullcontext()
12+
threadsafe = _PYTHON_INJECTION_THREADSAFE if threadsafe is None else threadsafe
13+
return RLock() if threadsafe else nullcontext()

injection/_core/injectables.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
2-
from collections.abc import Awaitable, Callable, MutableMapping
3-
from contextlib import suppress
2+
from collections.abc import Awaitable, Callable, Iterator, MutableMapping
3+
from contextlib import contextmanager, suppress
44
from dataclasses import dataclass, field
55
from functools import partial
66
from typing import (
@@ -53,11 +53,13 @@ def get_instance(self) -> T:
5353

5454

5555
class CacheLogic[T]:
56-
__slots__ = ("__semaphore",)
56+
__slots__ = ("__is_instantiating", "__semaphore")
5757

58+
__is_instantiating: bool
5859
__semaphore: AsyncContextManager[Any]
5960

6061
def __init__(self) -> None:
62+
self.__is_instantiating = False
6163
self.__semaphore = AsyncSemaphore(1)
6264

6365
async def aget_or_create[K](
@@ -66,11 +68,14 @@ async def aget_or_create[K](
6668
key: K,
6769
factory: Callable[..., Awaitable[T]],
6870
) -> T:
71+
self.__fail_if_instantiating()
6972
async with self.__semaphore:
7073
with suppress(KeyError):
7174
return cache[key]
7275

73-
instance = await factory()
76+
with self.__instantiating():
77+
instance = await factory()
78+
7479
cache[key] = instance
7580

7681
return instance
@@ -81,13 +86,29 @@ def get_or_create[K](
8186
key: K,
8287
factory: Callable[..., T],
8388
) -> T:
89+
self.__fail_if_instantiating()
8490
with suppress(KeyError):
8591
return cache[key]
8692

87-
instance = factory()
93+
with self.__instantiating():
94+
instance = factory()
95+
8896
cache[key] = instance
8997
return instance
9098

99+
def __fail_if_instantiating(self) -> None:
100+
if self.__is_instantiating:
101+
raise RecursionError("Recursive call detected during instantiation.")
102+
103+
@contextmanager
104+
def __instantiating(self) -> Iterator[None]:
105+
self.__is_instantiating = True
106+
107+
try:
108+
yield
109+
finally:
110+
self.__is_instantiating = False
111+
91112

92113
@dataclass(repr=False, eq=False, frozen=True, slots=True)
93114
class SingletonInjectable[T](Injectable[T]):

injection/_core/locator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def make_injected_function[**P, T](
6161
self,
6262
wrapped: Callable[P, T],
6363
/,
64+
threadsafe: bool | None = ...,
6465
) -> Callable[P, T]:
6566
raise NotImplementedError
6667

@@ -108,24 +109,24 @@ def request(self, provider: InjectionProvider) -> Injectable[T]:
108109

109110
injectable = _make_injectable(
110111
self.factory,
111-
provider.make_injected_function(self.recipe), # type: ignore[misc]
112+
provider.make_injected_function(self.recipe, threadsafe=False), # type: ignore[misc]
112113
)
113114
self.injectables[provider] = injectable
114115
return injectable
115116

116117

117118
@dataclass(repr=False, eq=False, frozen=True, slots=True)
118119
class StaticInjectableBroker[T](InjectableBroker[T]):
119-
value: Injectable[T]
120+
injectable: Injectable[T]
120121

121122
def get(self, provider: InjectionProvider) -> Injectable[T] | None:
122-
return self.value
123+
return self.injectable
123124

124125
def is_locked(self, provider: InjectionProvider) -> bool:
125126
return False
126127

127128
def request(self, provider: InjectionProvider) -> Injectable[T]:
128-
return self.value
129+
return self.injectable
129130

130131
@classmethod
131132
def from_factory(

injection/_core/scope.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def _bind_scope(
191191
kind: ScopeKind | ScopeKindStr,
192192
threadsafe: bool | None,
193193
) -> Iterator[ScopeFacade]:
194-
kind = ScopeKind(kind)
195194
lock = get_lock(threadsafe)
196195

197196
with lock:

tests/test_injectable.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,18 @@ def my_enum_recipe() -> MyEnum:
229229

230230
value = get_instance(MyEnum)
231231
assert isinstance(value, MyEnum)
232+
233+
async def test_injectable_with_circular_dependency_raise_recursion_error(self):
234+
class A: ...
235+
236+
@injectable
237+
@dataclass
238+
class B:
239+
a: A
240+
241+
@injectable
242+
async def a_factory(_b: B) -> A:
243+
return A()
244+
245+
with pytest.raises(RecursionError):
246+
await aget_instance(A)

tests/test_singleton.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,18 @@ class C(B): ...
183183

184184
a = get_instance(A)
185185
assert isinstance(a, C)
186+
187+
async def test_singleton_with_circular_dependency_raise_recursion_error(self):
188+
class A: ...
189+
190+
@singleton
191+
@dataclass
192+
class B:
193+
a: A
194+
195+
@singleton
196+
async def a_factory(_b: B) -> A:
197+
return A()
198+
199+
with pytest.raises(RecursionError):
200+
await aget_instance(A)

0 commit comments

Comments
 (0)