Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions aiomisc/thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from types import MappingProxyType
from typing import (
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Generator, Generic,
Optional, Set, Tuple, TypeVar, Union, overload,
Optional, Set, Tuple, TypeVar, Union, overload, MutableMapping,
)
from weakref import WeakKeyDictionary

from ._context_vars import EVENT_LOOP
from .compat import Concatenate, ParamSpec
Expand Down Expand Up @@ -379,6 +380,8 @@ class Threaded(ThreadedBase[P, T]):
func_type: type

def __init__(self, func: Callable[P, T]) -> None:
self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary()

if isinstance(func, staticmethod):
self.func_type = staticmethod
self.func = func.__func__
Expand Down Expand Up @@ -415,14 +418,22 @@ def __get__(
instance: Any,
owner: Optional[type] = None,
) -> "Threaded[P, T] | BoundThreaded[Any, T]":
key = instance
result: Any
if key in self.__cache:
return self.__cache[key]
if self.func_type is staticmethod:
return self
result = self
elif self.func_type is classmethod:
cls = owner if instance is None else type(instance)
return BoundThreaded(self.func, cls)
result = BoundThreaded(self.func, cls)
elif instance is not None:
return BoundThreaded(self.func, instance)
return self
result = BoundThreaded(self.func, instance)
else:
result = self

self.__cache[key] = result
return result


class BoundThreaded(ThreadedBase[P, T]):
Expand Down Expand Up @@ -570,6 +581,7 @@ def __init__(

self.func = actual_func
self.max_size = max_size
self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary()

@overload
def __get__(
Expand All @@ -592,14 +604,23 @@ def __get__(
instance: Any,
owner: Optional[type] = None,
) -> "ThreadedIterable[P, T] | BoundThreadedIterable[Any, T]":
key = instance
result: Any
if key in self.__cache:
return self.__cache[key]

if self.func_type is staticmethod:
return self
result = self
elif self.func_type is classmethod:
cls = owner if instance is None else type(instance)
return BoundThreadedIterable(self.func, cls, self.max_size)
result = BoundThreadedIterable(self.func, cls, self.max_size)
elif instance is not None:
return BoundThreadedIterable(self.func, instance, self.max_size)
return self
result = BoundThreadedIterable(self.func, instance, self.max_size)
else:
result = self

self.__cache[key] = result
return result


class BoundThreadedIterable(ThreadedIterableBase[P, T]):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ def foo(self):
return 42

instance = TestClass()
assert instance.foo is instance.foo
assert instance.foo.sync_call() == 42
assert await instance.foo() == 42
assert await instance.foo.async_call() == 42
Expand All @@ -597,6 +598,7 @@ def foo():
return 42

instance = TestClass()
assert instance.foo is instance.foo
assert instance.foo.sync_call() == 42
assert await instance.foo() == 42
assert await instance.foo.async_call() == 42
Expand All @@ -610,6 +612,7 @@ def foo(cls):
return 42

instance = TestClass()
assert instance.foo is instance.foo
assert instance.foo.sync_call() == 42
assert await instance.foo() == 42
assert await instance.foo.async_call() == 42
Expand All @@ -620,6 +623,7 @@ async def test_threaded_iterator_class_func():
def foo():
yield 42

assert foo is foo
assert list(foo.sync_call()) == [42]
assert [x async for x in foo()] == [42]
assert [x async for x in foo.async_call()] == [42]
Expand All @@ -632,6 +636,7 @@ def foo(self):
yield 42

instance = TestClass()
assert instance.foo is instance.foo
assert list(instance.foo.sync_call()) == [42]
assert [x async for x in instance.foo()] == [42]
assert [x async for x in instance.foo.async_call()] == [42]
Expand All @@ -645,6 +650,7 @@ def foo():
yield 42

instance = TestClass()
assert instance.foo is instance.foo
assert list(instance.foo.sync_call()) == [42]
assert [x async for x in instance.foo()] == [42]
assert [x async for x in instance.foo.async_call()] == [42]
Expand All @@ -658,6 +664,7 @@ def foo(cls):
yield 42

instance = TestClass()
assert instance.foo is instance.foo
assert list(instance.foo.sync_call()) == [42]
assert [x async for x in instance.foo()] == [42]
assert [x async for x in instance.foo.async_call()] == [42]
Loading