22
33import asyncio
44import logging
5- import sys
6- import threading
75import typing as t
86from dataclasses import dataclass , field
97
1614logger = logging .getLogger (__name__ )
1715
1816
19- def runner_exception_hook (args : threading .ExceptHookArgs ):
20- raise args .exc_type
def is_event_loop_running() -> bool:
    """Report whether an asyncio event loop is running in the current thread.

    Returns:
        ``True`` when called while an event loop is running in this thread
        (e.g. from inside a coroutine), ``False`` otherwise.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        return False
    return loop.is_running()
2124
2225
23- # set a custom exception hook
24- # threading.excepthook = runner_exception_hook
25-
26-
def as_completed(coros, max_workers):
    """Return an iterator over *coros* that yields awaitables as they finish.

    Args:
        coros: Iterable of coroutines (or other awaitables) to run.
        max_workers: Maximum number of coroutines allowed to run at the same
            time. ``-1`` disables the limit.

    Returns:
        The iterator produced by ``asyncio.as_completed``; awaiting each
        yielded item gives the result of the next coroutine to finish.

    Raises:
        ValueError: If ``max_workers`` is ``0`` or less than ``-1``.
    """
    if max_workers == -1:
        return asyncio.as_completed(coros)

    if max_workers < 1:
        # A semaphore of 0 would make every wrapped coroutine wait forever,
        # and asyncio.Semaphore itself rejects negative values.
        raise ValueError(
            f"max_workers must be -1 (unlimited) or a positive integer, got {max_workers}"
        )

    semaphore = asyncio.Semaphore(max_workers)

    async def sema_coro(coro):
        # Hold a semaphore slot for the whole lifetime of the wrapped coroutine.
        async with semaphore:
            return await coro

    return asyncio.as_completed([sema_coro(c) for c in coros])
9039
9140
9241@dataclass
@@ -95,21 +44,22 @@ class Executor:
9544 keep_progress_bar : bool = True
9645 jobs : t .List [t .Any ] = field (default_factory = list , repr = False )
9746 raise_exceptions : bool = False
98- run_config : t .Optional [RunConfig ] = field (default_factory = RunConfig , repr = False )
47+ run_config : t .Optional [RunConfig ] = field (default = None , repr = False )
9948
10049 def wrap_callable_with_index (self , callable : t .Callable , counter ):
10150 async def wrapped_callable_async (* args , ** kwargs ):
10251 result = np .nan
10352 try :
10453 result = await callable (* args , ** kwargs )
10554 except MaxRetriesExceeded as e :
55+ # this only for testset generation v2
10656 logger .warning (f"max retries exceeded for { e .evolution } " )
10757 except Exception as e :
10858 if self .raise_exceptions :
10959 raise e
11060 else :
11161 logger .error (
112- "Runner in Executor raised an exception" , exc_info = True
62+ "Runner in Executor raised an exception" , exc_info = False
11363 )
11464
11565 return counter , result
@@ -120,29 +70,40 @@ def submit(
12070 self , callable : t .Callable , * args , name : t .Optional [str ] = None , ** kwargs
12171 ):
12272 callable_with_index = self .wrap_callable_with_index (callable , len (self .jobs ))
123- self .jobs .append ((callable_with_index ( * args , ** kwargs ) , name ))
73+ self .jobs .append ((callable_with_index , args , kwargs , name ))
12474
def results(self) -> t.List[t.Any]:
    """Run every submitted job and return the results in submission order.

    Each entry of ``self.jobs`` is a ``(async_callable, args, kwargs, name)``
    tuple; the callable is invoked here, so the coroutines are created at
    run time rather than at submission time.

    Returns:
        The second element of each job's ``(index, result)`` pair, ordered by
        index.

    Raises:
        ImportError: If an event loop is already running and ``nest_asyncio``
            is not installed.
    """
    if is_event_loop_running():
        # asyncio.run() cannot be called from a running loop (typical in
        # notebook environments); nest_asyncio is applied to allow it.
        try:
            import nest_asyncio
        except ImportError as err:
            raise ImportError(
                "It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
            ) from err

        nest_asyncio.apply()

    async def _aresults() -> t.List[t.Any]:
        # Build the as-completed iterator inside the coroutine so that the
        # tasks and semaphore are bound to the loop asyncio.run() started;
        # newer Python versions construct the iterator eagerly.
        futures_as_they_finish = as_completed(
            coros=[afunc(*args, **kwargs) for afunc, args, kwargs, _ in self.jobs],
            max_workers=(self.run_config or RunConfig()).max_workers,
        )

        results = []
        for future in tqdm(
            futures_as_they_finish,
            desc=self.desc,
            total=len(self.jobs),
            # whether you want to keep the progress bar after completion
            leave=self.keep_progress_bar,
        ):
            results.append(await future)

        return results

    results = asyncio.run(_aresults())
    # Jobs finish in arbitrary order; the leading index restores submit order.
    sorted_results = sorted(results, key=lambda x: x[0])
    return [r[1] for r in sorted_results]
0 commit comments