22import contextvars
33import inspect
44import logging
5+ import os
56import threading
67import time
78import warnings
1213from queue import SimpleQueue
1314from types import MappingProxyType
1415from typing import (
15- Any , Awaitable , Callable , Coroutine , Dict , FrozenSet , Optional , Set , Tuple ,
16- TypeVar ,
16+ Any , Awaitable , Callable , Coroutine , Dict , FrozenSet , Generic ,
17+ Optional , Set , Tuple , TypeVar , Generator , overload , Union
1718)
1819
1920from ._context_vars import EVENT_LOOP
2021from .compat import ParamSpec
2122from .counters import Statistic
2223from .iterator_wrapper import IteratorWrapper
2324
24-
2525P = ParamSpec ("P" )
2626T = TypeVar ("T" )
2727F = TypeVar ("F" , bound = Callable [..., Any ])
2828log = logging .getLogger (__name__ )
2929
30+ THREADED_ITERABLE_DEFAULT_MAX_SIZE = int (
31+ os .getenv ("THREADED_ITERABLE_DEFAULT_MAX_SIZE" , 1024 )
32+ )
33+
3034
3135def context_partial (
3236 func : F , * args : Any ,
@@ -327,6 +331,7 @@ async def lazy_wrapper() -> T:
327331 return await loop .run_in_executor (
328332 executor , partial (func , * args , ** kwargs ),
329333 )
334+
330335 return lazy_wrapper ()
331336
332337
@@ -340,22 +345,54 @@ async def _awaiter(future: asyncio.Future) -> T:
340345 raise
341346
342347
348+ class Threaded (Generic [P , T ]):
349+ __slots__ = ("func" ,)
350+
351+ def __init__ (self , func : Callable [P , T ]) -> None :
352+ if asyncio .iscoroutinefunction (func ):
353+ raise TypeError ("Can not wrap coroutine" )
354+ if inspect .isgeneratorfunction (func ):
355+ raise TypeError ("Can not wrap generator function" )
356+ self .func = func
357+
358+ def sync_call (self , * args : P .args , ** kwargs : P .kwargs ) -> T :
359+ return self .func (* args , ** kwargs )
360+
361+ def async_call (self , * args : P .args , ** kwargs : P .kwargs ) -> Awaitable [T ]:
362+ return run_in_executor (func = self .func , args = args , kwargs = kwargs )
363+
364+ def __repr__ (self ) -> str :
365+ return f"<Threaded { self .func .__name__ } at { id (self ):#x} >"
366+
367+ def __call__ (self , * args : P .args , ** kwargs : P .kwargs ) -> Awaitable [T ]:
368+ return self .async_call (* args , ** kwargs )
369+
370+ def __get__ (self , instance : Any , owner : Optional [type ] = None ) -> Any :
371+ if instance is None :
372+ return self
373+ return partial (self .async_call , instance )
374+
375+
376+ @overload
377+ def threaded (func : Callable [P , T ]) -> Threaded [P , T ]: ...
378+
379+
380+ @overload
343381def threaded (
344- func : Callable [P , T ],
345- ) -> Callable [P , Awaitable [T ]]:
346- if asyncio .iscoroutinefunction (func ):
347- raise TypeError ("Can not wrap coroutine" )
382+ func : Callable [P , Generator [T , None , None ]]
383+ ) -> Callable [P , IteratorWrapper [P , T ]]: ...
348384
349- if inspect .isgeneratorfunction (func ):
350- return threaded_iterable (func )
351385
352- @wraps (func )
353- def wrap (
354- * args : P .args , ** kwargs : P .kwargs ,
355- ) -> Awaitable [T ]:
356- return run_in_executor (func = func , args = args , kwargs = kwargs )
386+ def threaded (
387+ func : Callable [P , T ] | Callable [P , Generator [T , None , None ]]
388+ ) -> Threaded [P , T ] | Callable [P , IteratorWrapper [P , T ]]:
389+ if inspect .isgeneratorfunction (func ):
390+ return threaded_iterable (
391+ func ,
392+ max_size = THREADED_ITERABLE_DEFAULT_MAX_SIZE
393+ )
357394
358- return wrap
395+ return Threaded ( func ) # type: ignore
359396
360397
361398def run_in_new_thread (
@@ -390,67 +427,156 @@ def run_in_new_thread(
390427 return future
391428
392429
430+ class ThreadedSeparate (Threaded [P , T ]):
431+ """
432+ A decorator to run a function in a separate thread.
433+ It returns an `asyncio.Future` that can be awaited.
434+ """
435+
436+ def __init__ (self , func : Callable [P , T ], detach : bool = True ) -> None :
437+ super ().__init__ (func )
438+ self .detach = detach
439+
440+ def async_call (self , * args : P .args , ** kwargs : P .kwargs ) -> Awaitable [T ]:
441+ return run_in_new_thread (
442+ self .func , args = args , kwargs = kwargs , detach = self .detach ,
443+ )
444+
445+
393446def threaded_separate (
394- func : F ,
447+ func : Callable [ P , T ] ,
395448 detach : bool = True ,
396- ) -> Callable [..., Awaitable [ Any ] ]:
449+ ) -> ThreadedSeparate [ P , T ]:
397450 if isinstance (func , bool ):
398451 return partial (threaded_separate , detach = detach )
399452
400453 if asyncio .iscoroutinefunction (func ):
401454 raise TypeError ("Can not wrap coroutine" )
402455
403- @wraps (func )
404- def wrap (* args : Any , ** kwargs : Any ) -> Any :
405- future = run_in_new_thread (
406- func , args = args , kwargs = kwargs , detach = detach ,
456+ return ThreadedSeparate (func , detach = detach )
457+
458+
459+ class ThreadedIterable (Generic [P , T ]):
460+ def __init__ (
461+ self ,
462+ func : Callable [P , Generator [T , None , None ]],
463+ max_size : int = 0
464+ ) -> None :
465+ self .func = func
466+ self .max_size = max_size
467+
468+ def sync_call (
469+ self , * args : P .args , ** kwargs : P .kwargs
470+ ) -> Generator [T , None , None ]:
471+ return self .func (* args , ** kwargs )
472+
473+ def async_call (
474+ self , * args : P .args , ** kwargs : P .kwargs
475+ ) -> IteratorWrapper [P , T ]:
476+ return self .create_wrapper (* args , ** kwargs )
477+
478+ def create_wrapper (
479+ self , * args : P .args , ** kwargs : P .kwargs
480+ ) -> IteratorWrapper [P , T ]:
481+ return IteratorWrapper (
482+ partial (self .func , * args , ** kwargs ),
483+ max_size = self .max_size ,
407484 )
408- return future
409485
410- return wrap
486+ def __call__ (
487+ self ,
488+ * args : P .args ,
489+ ** kwargs : P .kwargs
490+ ) -> IteratorWrapper [P , T ]:
491+ return self .async_call (* args , ** kwargs )
411492
493+ def __get__ (
494+ self ,
495+ instance : Any ,
496+ owner : Optional [type ] = None
497+ ) -> Any :
498+ if instance is None :
499+ return self
500+ return partial (self .async_call , instance )
412501
502+
503+ @overload
413504def threaded_iterable (
414- func : Optional [F ] = None ,
505+ func : Callable [P , Generator [T , None , None ]],
506+ * ,
415507 max_size : int = 0 ,
416- ) -> Any :
417- if isinstance (func , int ):
418- return partial (threaded_iterable , max_size = func )
419- if func is None :
420- return partial (threaded_iterable , max_size = max_size )
508+ ) -> "ThreadedIterable[P, T]" : ...
421509
422- @wraps (func )
423- def wrap (* args : Any , ** kwargs : Any ) -> Any :
424- return IteratorWrapper (
425- partial (func , * args , ** kwargs ),
426- max_size = max_size ,
427- )
428510
429- return wrap
511+ @overload
512+ def threaded_iterable (
513+ * ,
514+ max_size : int = 0 ,
515+ ) -> Callable [
516+ [Callable [P , Generator [T , None , None ]]], ThreadedIterable [P , T ]]: ...
430517
431518
519+ def threaded_iterable (
520+ func : Optional [Callable [P , Generator [T , None , None ]]] = None ,
521+ * ,
522+ max_size : int = 0 ,
523+ ) -> Union [
524+ ThreadedIterable [P , T ],
525+ Callable [[Callable [P , Generator [T , None , None ]]],
526+ ThreadedIterable [P , T ]]
527+ ]:
528+ if func is None :
529+ return lambda f : ThreadedIterable (f , max_size = max_size )
530+
531+ return ThreadedIterable (func , max_size = max_size )
532+
432533class IteratorWrapperSeparate (IteratorWrapper ):
433534 def _run (self ) -> Any :
434535 return run_in_new_thread (self ._in_thread )
435536
436537
538+ class ThreadedIterableSeparate (ThreadedIterable [P , T ]):
539+ def create_wrapper (
540+ self , * args : P .args , ** kwargs : P .kwargs
541+ ) -> IteratorWrapperSeparate :
542+ return IteratorWrapperSeparate (
543+ partial (self .func , * args , ** kwargs ),
544+ max_size = self .max_size ,
545+ )
546+
547+
548+ @overload
437549def threaded_iterable_separate (
438- func : Optional [F ] = None ,
550+ func : Callable [P , Generator [T , None , None ]],
551+ * ,
439552 max_size : int = 0 ,
440- ) -> Any :
441- if isinstance (func , int ):
442- return partial (threaded_iterable_separate , max_size = func )
553+ ) -> "ThreadedIterable[P, T]" : ...
554+
555+
556+ @overload
557+ def threaded_iterable_separate (
558+ * ,
559+ max_size : int = 0 ,
560+ ) -> Callable [
561+ [Callable [P , Generator [T , None , None ]]],
562+ ThreadedIterableSeparate [P , T ]
563+ ]: ...
564+
565+
566+ def threaded_iterable_separate (
567+ func : Optional [Callable [P , Generator [T , None , None ]]] = None ,
568+ * ,
569+ max_size : int = 0 ,
570+ ) -> Union [
571+ ThreadedIterable [P , T ],
572+ Callable [[Callable [P , Generator [T , None , None ]]],
573+ ThreadedIterableSeparate [P , T ]]
574+ ]:
443575 if func is None :
444- return partial ( threaded_iterable_separate , max_size = max_size )
576+ return lambda f : ThreadedIterableSeparate ( f , max_size = max_size )
445577
446- @wraps (func )
447- def wrap (* args : Any , ** kwargs : Any ) -> Any :
448- return IteratorWrapperSeparate (
449- partial (func , * args , ** kwargs ),
450- max_size = max_size ,
451- )
578+ return ThreadedIterableSeparate (func , max_size = max_size )
452579
453- return wrap
454580
455581
456582class CoroutineWaiter :
@@ -509,4 +635,5 @@ def sync_await(
509635) -> T :
510636 async def awaiter () -> T :
511637 return await func (* args , ** kwargs )
638+
512639 return wait_coroutine (awaiter ())
0 commit comments