Skip to content

Commit f9021d2

Browse files
committed
Threaded is now class
1 parent 8e85546 commit f9021d2

File tree

4 files changed

+193
-72
lines changed

4 files changed

+193
-72
lines changed

aiomisc/iterator_wrapper.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from types import TracebackType
1010
from typing import (
1111
Any, AsyncIterator, Awaitable, Callable, Deque, Generator, NoReturn,
12-
Optional, Type, TypeVar, Union,
12+
Optional, Type, TypeVar, Union, Generic, ParamSpec,
1313
)
1414
from weakref import finalize
1515

@@ -19,8 +19,9 @@
1919

2020
T = TypeVar("T")
2121
R = TypeVar("R")
22+
P = ParamSpec("P")
2223

23-
GenType = Generator[T, R, None]
24+
GenType = Generator[T, None, None]
2425
FuncType = Callable[[], GenType]
2526

2627

@@ -144,7 +145,7 @@ class IteratorWrapperStatistic(Statistic):
144145
enqueued: int
145146

146147

147-
class IteratorWrapper(AsyncIterator, EventLoopMixin):
148+
class IteratorWrapper(Generic[P, T], AsyncIterator, EventLoopMixin):
148149
__slots__ = (
149150
"__channel",
150151
"__close_event",
@@ -155,9 +156,11 @@ class IteratorWrapper(AsyncIterator, EventLoopMixin):
155156
) + EventLoopMixin.__slots__
156157

157158
def __init__(
158-
self, gen_func: FuncType,
159+
self,
160+
gen_func: Callable[P, Generator[T, None, None]],
159161
loop: Optional[asyncio.AbstractEventLoop] = None,
160-
max_size: int = 0, executor: Optional[Executor] = None,
162+
max_size: int = 0,
163+
executor: Optional[Executor] = None,
161164
statistic_name: Optional[str] = None,
162165
):
163166

@@ -227,11 +230,9 @@ async def wait_closed(self) -> None:
227230
await asyncio.gather(self.__gen_task, return_exceptions=True)
228231

229232
def _run(self) -> Any:
230-
return self.loop.run_in_executor(
231-
self.executor, self._in_thread,
232-
)
233+
return self.loop.run_in_executor(self.executor, self._in_thread)
233234

234-
def __aiter__(self) -> AsyncIterator[Any]:
235+
def __aiter__(self) -> AsyncIterator[T]:
235236
if not self.loop.is_running():
236237
raise RuntimeError("Event loop is not running")
237238

@@ -242,7 +243,7 @@ def __aiter__(self) -> AsyncIterator[Any]:
242243
self.__gen_task = gen_task
243244
return IteratorProxy(self, self.close)
244245

245-
async def __anext__(self) -> Awaitable[T]:
246+
async def __anext__(self) -> T:
246247
try:
247248
item, is_exc = await self.__channel.get()
248249
except ChannelClosed:
@@ -269,13 +270,13 @@ async def __aexit__(
269270
await self.close()
270271

271272

272-
class IteratorProxy(AsyncIterator):
273+
class IteratorProxy(Generic[T], AsyncIterator):
273274
def __init__(
274-
self, iterator: AsyncIterator,
275+
self, iterator: AsyncIterator[T],
275276
finalizer: Callable[[], Any],
276277
):
277278
self.__iterator = iterator
278279
finalize(self, finalizer)
279280

280-
def __anext__(self) -> Awaitable[Any]:
281+
def __anext__(self) -> Awaitable[T]:
281282
return self.__iterator.__anext__()

aiomisc/service/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
T = TypeVar("T")
12-
CoroutineType = Union[Coroutine[Any, Any, T], Generator[Any, None, T]]
12+
CoroutineType = Union[Coroutine[Any, Any, T]]
1313

1414

1515
class ServiceMeta(ABCMeta):

aiomisc/thread_pool.py

Lines changed: 175 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextvars
33
import inspect
44
import logging
5+
import os
56
import threading
67
import time
78
import warnings
@@ -12,21 +13,24 @@
1213
from queue import SimpleQueue
1314
from types import MappingProxyType
1415
from typing import (
15-
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Optional, Set, Tuple,
16-
TypeVar,
16+
Any, Awaitable, Callable, Coroutine, Dict, FrozenSet, Generic,
17+
Optional, Set, Tuple, TypeVar, Generator, overload, Union
1718
)
1819

1920
from ._context_vars import EVENT_LOOP
2021
from .compat import ParamSpec
2122
from .counters import Statistic
2223
from .iterator_wrapper import IteratorWrapper
2324

24-
2525
P = ParamSpec("P")
2626
T = TypeVar("T")
2727
F = TypeVar("F", bound=Callable[..., Any])
2828
log = logging.getLogger(__name__)
2929

30+
THREADED_ITERABLE_DEFAULT_MAX_SIZE = int(
31+
os.getenv("THREADED_ITERABLE_DEFAULT_MAX_SIZE", 1024)
32+
)
33+
3034

3135
def context_partial(
3236
func: F, *args: Any,
@@ -327,6 +331,7 @@ async def lazy_wrapper() -> T:
327331
return await loop.run_in_executor(
328332
executor, partial(func, *args, **kwargs),
329333
)
334+
330335
return lazy_wrapper()
331336

332337

@@ -340,22 +345,54 @@ async def _awaiter(future: asyncio.Future) -> T:
340345
raise
341346

342347

348+
class Threaded(Generic[P, T]):
349+
__slots__ = ("func",)
350+
351+
def __init__(self, func: Callable[P, T]) -> None:
352+
if asyncio.iscoroutinefunction(func):
353+
raise TypeError("Can not wrap coroutine")
354+
if inspect.isgeneratorfunction(func):
355+
raise TypeError("Can not wrap generator function")
356+
self.func = func
357+
358+
def sync_call(self, *args: P.args, **kwargs: P.kwargs) -> T:
359+
return self.func(*args, **kwargs)
360+
361+
def async_call(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
362+
return run_in_executor(func=self.func, args=args, kwargs=kwargs)
363+
364+
def __repr__(self) -> str:
365+
return f"<Threaded {self.func.__name__} at {id(self):#x}>"
366+
367+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
368+
return self.async_call(*args, **kwargs)
369+
370+
def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
371+
if instance is None:
372+
return self
373+
return partial(self.async_call, instance)
374+
375+
376+
@overload
377+
def threaded(func: Callable[P, T]) -> Threaded[P, T]: ...
378+
379+
380+
@overload
343381
def threaded(
344-
func: Callable[P, T],
345-
) -> Callable[P, Awaitable[T]]:
346-
if asyncio.iscoroutinefunction(func):
347-
raise TypeError("Can not wrap coroutine")
382+
func: Callable[P, Generator[T, None, None]]
383+
) -> Callable[P, IteratorWrapper[P, T]]: ...
348384

349-
if inspect.isgeneratorfunction(func):
350-
return threaded_iterable(func)
351385

352-
@wraps(func)
353-
def wrap(
354-
*args: P.args, **kwargs: P.kwargs,
355-
) -> Awaitable[T]:
356-
return run_in_executor(func=func, args=args, kwargs=kwargs)
386+
def threaded(
387+
func: Callable[P, T] | Callable[P, Generator[T, None, None]]
388+
) -> Threaded[P, T] | Callable[P, IteratorWrapper[P, T]]:
389+
if inspect.isgeneratorfunction(func):
390+
return threaded_iterable(
391+
func,
392+
max_size=THREADED_ITERABLE_DEFAULT_MAX_SIZE
393+
)
357394

358-
return wrap
395+
return Threaded(func) # type: ignore
359396

360397

361398
def run_in_new_thread(
@@ -390,67 +427,156 @@ def run_in_new_thread(
390427
return future
391428

392429

430+
class ThreadedSeparate(Threaded[P, T]):
431+
"""
432+
A decorator to run a function in a separate thread.
433+
It returns an `asyncio.Future` that can be awaited.
434+
"""
435+
436+
def __init__(self, func: Callable[P, T], detach: bool = True) -> None:
437+
super().__init__(func)
438+
self.detach = detach
439+
440+
def async_call(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
441+
return run_in_new_thread(
442+
self.func, args=args, kwargs=kwargs, detach=self.detach,
443+
)
444+
445+
393446
def threaded_separate(
394-
func: F,
447+
func: Callable[P, T],
395448
detach: bool = True,
396-
) -> Callable[..., Awaitable[Any]]:
449+
) -> ThreadedSeparate[P, T]:
397450
if isinstance(func, bool):
398451
return partial(threaded_separate, detach=detach)
399452

400453
if asyncio.iscoroutinefunction(func):
401454
raise TypeError("Can not wrap coroutine")
402455

403-
@wraps(func)
404-
def wrap(*args: Any, **kwargs: Any) -> Any:
405-
future = run_in_new_thread(
406-
func, args=args, kwargs=kwargs, detach=detach,
456+
return ThreadedSeparate(func, detach=detach)
457+
458+
459+
class ThreadedIterable(Generic[P, T]):
460+
def __init__(
461+
self,
462+
func: Callable[P, Generator[T, None, None]],
463+
max_size: int = 0
464+
) -> None:
465+
self.func = func
466+
self.max_size = max_size
467+
468+
def sync_call(
469+
self, *args: P.args, **kwargs: P.kwargs
470+
) -> Generator[T, None, None]:
471+
return self.func(*args, **kwargs)
472+
473+
def async_call(
474+
self, *args: P.args, **kwargs: P.kwargs
475+
) -> IteratorWrapper[P, T]:
476+
return self.create_wrapper(*args, **kwargs)
477+
478+
def create_wrapper(
479+
self, *args: P.args, **kwargs: P.kwargs
480+
) -> IteratorWrapper[P, T]:
481+
return IteratorWrapper(
482+
partial(self.func, *args, **kwargs),
483+
max_size=self.max_size,
407484
)
408-
return future
409485

410-
return wrap
486+
def __call__(
487+
self,
488+
*args: P.args,
489+
**kwargs: P.kwargs
490+
) -> IteratorWrapper[P, T]:
491+
return self.async_call(*args, **kwargs)
411492

493+
def __get__(
494+
self,
495+
instance: Any,
496+
owner: Optional[type] = None
497+
) -> Any:
498+
if instance is None:
499+
return self
500+
return partial(self.async_call, instance)
412501

502+
503+
@overload
413504
def threaded_iterable(
414-
func: Optional[F] = None,
505+
func: Callable[P, Generator[T, None, None]],
506+
*,
415507
max_size: int = 0,
416-
) -> Any:
417-
if isinstance(func, int):
418-
return partial(threaded_iterable, max_size=func)
419-
if func is None:
420-
return partial(threaded_iterable, max_size=max_size)
508+
) -> "ThreadedIterable[P, T]": ...
421509

422-
@wraps(func)
423-
def wrap(*args: Any, **kwargs: Any) -> Any:
424-
return IteratorWrapper(
425-
partial(func, *args, **kwargs),
426-
max_size=max_size,
427-
)
428510

429-
return wrap
511+
@overload
512+
def threaded_iterable(
513+
*,
514+
max_size: int = 0,
515+
) -> Callable[
516+
[Callable[P, Generator[T, None, None]]], ThreadedIterable[P, T]]: ...
430517

431518

519+
def threaded_iterable(
520+
func: Optional[Callable[P, Generator[T, None, None]]] = None,
521+
*,
522+
max_size: int = 0,
523+
) -> Union[
524+
ThreadedIterable[P, T],
525+
Callable[[Callable[P, Generator[T, None, None]]],
526+
ThreadedIterable[P, T]]
527+
]:
528+
if func is None:
529+
return lambda f: ThreadedIterable(f, max_size=max_size)
530+
531+
return ThreadedIterable(func, max_size=max_size)
532+
432533
class IteratorWrapperSeparate(IteratorWrapper):
433534
def _run(self) -> Any:
434535
return run_in_new_thread(self._in_thread)
435536

436537

538+
class ThreadedIterableSeparate(ThreadedIterable[P, T]):
539+
def create_wrapper(
540+
self, *args: P.args, **kwargs: P.kwargs
541+
) -> IteratorWrapperSeparate:
542+
return IteratorWrapperSeparate(
543+
partial(self.func, *args, **kwargs),
544+
max_size=self.max_size,
545+
)
546+
547+
548+
@overload
437549
def threaded_iterable_separate(
438-
func: Optional[F] = None,
550+
func: Callable[P, Generator[T, None, None]],
551+
*,
439552
max_size: int = 0,
440-
) -> Any:
441-
if isinstance(func, int):
442-
return partial(threaded_iterable_separate, max_size=func)
553+
) -> "ThreadedIterable[P, T]": ...
554+
555+
556+
@overload
557+
def threaded_iterable_separate(
558+
*,
559+
max_size: int = 0,
560+
) -> Callable[
561+
[Callable[P, Generator[T, None, None]]],
562+
ThreadedIterableSeparate[P, T]
563+
]: ...
564+
565+
566+
def threaded_iterable_separate(
567+
func: Optional[Callable[P, Generator[T, None, None]]] = None,
568+
*,
569+
max_size: int = 0,
570+
) -> Union[
571+
ThreadedIterable[P, T],
572+
Callable[[Callable[P, Generator[T, None, None]]],
573+
ThreadedIterableSeparate[P, T]]
574+
]:
443575
if func is None:
444-
return partial(threaded_iterable_separate, max_size=max_size)
576+
return lambda f: ThreadedIterableSeparate(f, max_size=max_size)
445577

446-
@wraps(func)
447-
def wrap(*args: Any, **kwargs: Any) -> Any:
448-
return IteratorWrapperSeparate(
449-
partial(func, *args, **kwargs),
450-
max_size=max_size,
451-
)
578+
return ThreadedIterableSeparate(func, max_size=max_size)
452579

453-
return wrap
454580

455581

456582
class CoroutineWaiter:
@@ -509,4 +635,5 @@ def sync_await(
509635
) -> T:
510636
async def awaiter() -> T:
511637
return await func(*args, **kwargs)
638+
512639
return wait_coroutine(awaiter())

0 commit comments

Comments
 (0)