Skip to content

Commit abb0943

Browse files
committed
Cache getter instance
1 parent 83629d7 commit abb0943

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

aiomisc/thread_pool.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ class Threaded(ThreadedBase[P, T]):
379379
func_type: type
380380

381381
def __init__(self, func: Callable[P, T]) -> None:
382+
self.__wrapped: Dict[Any, Any] = {}
383+
382384
if isinstance(func, staticmethod):
383385
self.func_type = staticmethod
384386
self.func = func.__func__
@@ -415,14 +417,22 @@ def __get__(
415417
instance: Any,
416418
owner: Optional[type] = None,
417419
) -> "Threaded[P, T] | BoundThreaded[Any, T]":
420+
key = (instance, owner)
421+
if key in self.__wrapped:
422+
return self.__wrapped[key]
423+
418424
if self.func_type is staticmethod:
419-
return self
425+
result = self
420426
elif self.func_type is classmethod:
421427
cls = owner if instance is None else type(instance)
422-
return BoundThreaded(self.func, cls)
428+
result = BoundThreaded(self.func, cls)
423429
elif instance is not None:
424-
return BoundThreaded(self.func, instance)
425-
return self
430+
result = BoundThreaded(self.func, instance)
431+
else:
432+
result = self
433+
434+
self.__wrapped[key] = result
435+
return result
426436

427437

428438
class BoundThreaded(ThreadedBase[P, T]):
@@ -570,6 +580,7 @@ def __init__(
570580

571581
self.func = actual_func
572582
self.max_size = max_size
583+
self.__wrapped: Dict[Any, Any] = {}
573584

574585
@overload
575586
def __get__(
@@ -592,14 +603,22 @@ def __get__(
592603
instance: Any,
593604
owner: Optional[type] = None,
594605
) -> "ThreadedIterable[P, T] | BoundThreadedIterable[Any, T]":
606+
key = (instance, owner)
607+
if key in self.__wrapped:
608+
return self.__wrapped[key]
609+
595610
if self.func_type is staticmethod:
596-
return self
611+
result = self
597612
elif self.func_type is classmethod:
598613
cls = owner if instance is None else type(instance)
599-
return BoundThreadedIterable(self.func, cls, self.max_size)
614+
result = BoundThreadedIterable(self.func, cls, self.max_size)
600615
elif instance is not None:
601-
return BoundThreadedIterable(self.func, instance, self.max_size)
602-
return self
616+
result = BoundThreadedIterable(self.func, instance, self.max_size)
617+
else:
618+
result = self
619+
620+
self.__wrapped[key] = result
621+
return result
603622

604623

605624
class BoundThreadedIterable(ThreadedIterableBase[P, T]):

tests/test_thread_pool.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def foo(self):
584584
return 42
585585

586586
instance = TestClass()
587+
assert instance.foo is instance.foo
587588
assert instance.foo.sync_call() == 42
588589
assert await instance.foo() == 42
589590
assert await instance.foo.async_call() == 42
@@ -597,6 +598,7 @@ def foo():
597598
return 42
598599

599600
instance = TestClass()
601+
assert instance.foo is instance.foo
600602
assert instance.foo.sync_call() == 42
601603
assert await instance.foo() == 42
602604
assert await instance.foo.async_call() == 42
@@ -610,6 +612,7 @@ def foo(cls):
610612
return 42
611613

612614
instance = TestClass()
615+
assert instance.foo is instance.foo
613616
assert instance.foo.sync_call() == 42
614617
assert await instance.foo() == 42
615618
assert await instance.foo.async_call() == 42
@@ -620,6 +623,7 @@ async def test_threaded_iterator_class_func():
620623
def foo():
621624
yield 42
622625

626+
assert foo is foo
623627
assert list(foo.sync_call()) == [42]
624628
assert [x async for x in foo()] == [42]
625629
assert [x async for x in foo.async_call()] == [42]
@@ -632,6 +636,7 @@ def foo(self):
632636
yield 42
633637

634638
instance = TestClass()
639+
assert instance.foo is instance.foo
635640
assert list(instance.foo.sync_call()) == [42]
636641
assert [x async for x in instance.foo()] == [42]
637642
assert [x async for x in instance.foo.async_call()] == [42]
@@ -645,6 +650,7 @@ def foo():
645650
yield 42
646651

647652
instance = TestClass()
653+
assert instance.foo is instance.foo
648654
assert list(instance.foo.sync_call()) == [42]
649655
assert [x async for x in instance.foo()] == [42]
650656
assert [x async for x in instance.foo.async_call()] == [42]
@@ -658,6 +664,7 @@ def foo(cls):
658664
yield 42
659665

660666
instance = TestClass()
667+
assert instance.foo is instance.foo
661668
assert list(instance.foo.sync_call()) == [42]
662669
assert [x async for x in instance.foo()] == [42]
663670
assert [x async for x in instance.foo.async_call()] == [42]

0 commit comments

Comments
 (0)