Skip to content

Commit 5b87c45

Browse files
committed
Cache getter instance
1 parent 83629d7 commit 5b87c45

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

aiomisc/thread_pool.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
from types import MappingProxyType
1616
from typing import (
1717
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Generator, Generic,
18-
Optional, Set, Tuple, TypeVar, Union, overload,
18+
Optional, Set, Tuple, TypeVar, Union, overload, MutableMapping,
1919
)
20+
from weakref import WeakKeyDictionary
2021

2122
from ._context_vars import EVENT_LOOP
2223
from .compat import Concatenate, ParamSpec
@@ -379,6 +380,8 @@ class Threaded(ThreadedBase[P, T]):
379380
func_type: type
380381

381382
def __init__(self, func: Callable[P, T]) -> None:
383+
self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary()
384+
382385
if isinstance(func, staticmethod):
383386
self.func_type = staticmethod
384387
self.func = func.__func__
@@ -415,14 +418,22 @@ def __get__(
415418
instance: Any,
416419
owner: Optional[type] = None,
417420
) -> "Threaded[P, T] | BoundThreaded[Any, T]":
421+
key = instance
422+
if key in self.__cache:
423+
return self.__cache[key]
424+
418425
if self.func_type is staticmethod:
419-
return self
426+
result = self
420427
elif self.func_type is classmethod:
421428
cls = owner if instance is None else type(instance)
422-
return BoundThreaded(self.func, cls)
429+
result = BoundThreaded(self.func, cls)
423430
elif instance is not None:
424-
return BoundThreaded(self.func, instance)
425-
return self
431+
result = BoundThreaded(self.func, instance)
432+
else:
433+
result = self
434+
435+
self.__cache[key] = result
436+
return result
426437

427438

428439
class BoundThreaded(ThreadedBase[P, T]):
@@ -570,6 +581,7 @@ def __init__(
570581

571582
self.func = actual_func
572583
self.max_size = max_size
584+
self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary()
573585

574586
@overload
575587
def __get__(
@@ -592,14 +604,22 @@ def __get__(
592604
instance: Any,
593605
owner: Optional[type] = None,
594606
) -> "ThreadedIterable[P, T] | BoundThreadedIterable[Any, T]":
607+
key = instance
608+
if key in self.__cache:
609+
return self.__cache[key]
610+
595611
if self.func_type is staticmethod:
596-
return self
612+
result = self
597613
elif self.func_type is classmethod:
598614
cls = owner if instance is None else type(instance)
599-
return BoundThreadedIterable(self.func, cls, self.max_size)
615+
result = BoundThreadedIterable(self.func, cls, self.max_size)
600616
elif instance is not None:
601-
return BoundThreadedIterable(self.func, instance, self.max_size)
602-
return self
617+
result = BoundThreadedIterable(self.func, instance, self.max_size)
618+
else:
619+
result = self
620+
621+
self.__cache[key] = result
622+
return result
603623

604624

605625
class BoundThreadedIterable(ThreadedIterableBase[P, T]):

tests/test_thread_pool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def foo(self):
584584
return 42
585585

586586
instance = TestClass()
587+
assert instance.foo is instance.foo
587588
assert instance.foo.sync_call() == 42
588589
assert await instance.foo() == 42
589590
assert await instance.foo.async_call() == 42
@@ -597,6 +598,7 @@ def foo():
597598
return 42
598599

599600
instance = TestClass()
601+
assert instance.foo is instance.foo
600602
assert instance.foo.sync_call() == 42
601603
assert await instance.foo() == 42
602604
assert await instance.foo.async_call() == 42
@@ -610,6 +612,7 @@ def foo(cls):
610612
return 42
611613

612614
instance = TestClass()
615+
assert instance.foo is instance.foo
613616
assert instance.foo.sync_call() == 42
614617
assert await instance.foo() == 42
615618
assert await instance.foo.async_call() == 42
@@ -620,6 +623,7 @@ async def test_threaded_iterator_class_func():
620623
def foo():
621624
yield 42
622625

626+
assert foo is foo
623627
assert list(foo.sync_call()) == [42]
624628
assert [x async for x in foo()] == [42]
625629
assert [x async for x in foo.async_call()] == [42]
@@ -632,6 +636,7 @@ def foo(self):
632636
yield 42
633637

634638
instance = TestClass()
639+
assert instance.foo is instance.foo
635640
assert list(instance.foo.sync_call()) == [42]
636641
assert [x async for x in instance.foo()] == [42]
637642
assert [x async for x in instance.foo.async_call()] == [42]
@@ -645,6 +650,7 @@ def foo():
645650
yield 42
646651

647652
instance = TestClass()
653+
assert instance.foo is instance.foo
648654
assert list(instance.foo.sync_call()) == [42]
649655
assert [x async for x in instance.foo()] == [42]
650656
assert [x async for x in instance.foo.async_call()] == [42]
@@ -658,6 +664,7 @@ def foo(cls):
658664
yield 42
659665

660666
instance = TestClass()
667+
assert instance.foo is instance.foo
661668
assert list(instance.foo.sync_call()) == [42]
662669
assert [x async for x in instance.foo()] == [42]
663670
assert [x async for x in instance.foo.async_call()] == [42]

0 commit comments

Comments
 (0)