import traceback
import warnings
from _asyncio import Future, Task
- from concurrent.futures.process import ProcessPoolExecutor
from contextlib import suppress
- from typing import Any, Callable, List, Optional, Set, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

- from distributed.cfexecutor import ClientExecutor
- from distributed.client import Client
- from ipyparallel.client.asyncresult import AsyncResult
- from ipyparallel.client.view import ViewExecutor
-
- from adaptive.learner import BaseLearner
+ from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import in_ipynb, live_info, live_plot

+ _ThirdPartyClient = []
+ _ThirdPartyExecutor = []
+
try:
    import ipyparallel
+     from ipyparallel.client.asyncresult import AsyncResult

    with_ipyparallel = True
+     _ThirdPartyClient.append(ipyparallel.Client)
+     _ThirdPartyExecutor.append(ipyparallel.client.view.ViewExecutor)
except ModuleNotFoundError:
    with_ipyparallel = False

try:
    import distributed

    with_distributed = True
+     _ThirdPartyClient.append(distributed.client.Client)
+     _ThirdPartyExecutor.append(distributed.cfexecutor.ClientExecutor)
except ModuleNotFoundError:
    with_distributed = False

try:
    import mpi4py.futures

    with_mpi4py = True
+     _ThirdPartyExecutor.append(mpi4py.futures.MPIPoolExecutor)
except ModuleNotFoundError:
    with_mpi4py = False

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+ ThirdPartyClient = Union[tuple(_ThirdPartyClient)]
+ ThirdPartyExecutor = Union[tuple(_ThirdPartyExecutor)]
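
The two _ThirdParty* lists above let the Union aliases be assembled only from the optional backends that actually imported. A minimal, standard-library-only sketch of the same pattern (the names below are illustrative, not part of this module):

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
    from typing import Union

    _executor_types = [ProcessPoolExecutor, ThreadPoolExecutor]

    # Union accepts a tuple of types, so the alias can be built at import time.
    AnyPoolExecutor = Union[tuple(_executor_types)]

    # The same tuple of types also works for runtime isinstance() checks.
    assert isinstance(ThreadPoolExecutor(max_workers=1), tuple(_executor_types))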

if os.name == "nt":
    if with_distributed:
@@ -72,8 +77,80 @@ def _default_executor(*args, **kwargs):
    _default_executor_kwargs = {}


+ # -- Internal executor-related, things
+
+
+ class SequentialExecutor(concurrent.Executor):
+     """A trivial executor that runs functions synchronously.
+
+     This executor is mainly for testing.
+     """
+
+     def submit(self, fn: Callable, *args, **kwargs) -> Future:
+         fut = concurrent.Future()
+         try:
+             fut.set_result(fn(*args, **kwargs))
+         except Exception as e:
+             fut.set_exception(e)
+         return fut
+
+     def map(self, fn, *iterable, timeout=None, chunksize=1):
+         return map(fn, iterable)
+
+     def shutdown(self, wait=True):
+         pass
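
A quick sketch of how the SequentialExecutor added here behaves, assuming it is importable from adaptive.runner as this diff suggests:

    from adaptive.runner import SequentialExecutor

    ex = SequentialExecutor()
    fut = ex.submit(pow, 2, 10)  # runs pow(2, 10) immediately, in the calling thread
    assert fut.done()            # the future is already resolved when submit() returns
    assert fut.result() == 1024
    ex.shutdown()                # no-op; there are no workers to tear down

Note that map() as written passes the tuple of iterables itself to the builtin map rather than unpacking it, so only submit() is exercised in this sketch.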
+
+
+ def _ensure_executor(
+     executor: Optional[Union[ThirdPartyClient, concurrent.Executor]]
+ ) -> concurrent.Executor:
+     if executor is None:
+         executor = _default_executor(**_default_executor_kwargs)
+
+     if isinstance(executor, concurrent.Executor):
+         return executor
+     elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
+         return executor.executor()
+     elif with_distributed and isinstance(executor, distributed.Client):
+         return executor.get_executor()
+     else:
+         raise TypeError(
+             "Only a concurrent.futures.Executor, distributed.Client,"
+             " or ipyparallel.Client can be used."
+         )
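
_ensure_executor is what lets a runner accept either a plain executor or a third-party client. A rough sketch of the dispatch (this is a private helper, and the optional backends appear only in comments):

    from concurrent.futures import ThreadPoolExecutor

    from adaptive.runner import _ensure_executor

    ex = _ensure_executor(ThreadPoolExecutor(max_workers=2))  # a concurrent.futures.Executor is returned as-is
    ex = _ensure_executor(None)                               # falls back to the platform's default executor
    # _ensure_executor(ipyparallel.Client())  ->  client.executor()      (if ipyparallel is installed)
    # _ensure_executor(distributed.Client())  ->  client.get_executor()  (if distributed is installed)
    # anything else raises TypeError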
+
+
+ def _get_ncores(
+     ex: Union[
+         ThirdPartyExecutor,
+         concurrent.ProcessPoolExecutor,
+         concurrent.ThreadPoolExecutor,
+         SequentialExecutor,
+     ]
+ ) -> int:
+     """Return the maximum number of cores that an executor can use."""
+     if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
+         return len(ex.view)
+     elif isinstance(
+         ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
+     ):
+         return ex._max_workers  # not public API!
+     elif isinstance(ex, SequentialExecutor):
+         return 1
+     elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
+         return sum(n for n in ex._client.ncores().values())
+     elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
+         ex.bootup()  # wait until all workers are up and running
+         return ex._pool.size  # not public API!
+     else:
+         raise TypeError(f"Cannot get number of cores for {ex.__class__}")
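
For the standard-library pools, _get_ncores reduces to the pool's worker count. A small illustrative sketch (private helper; worker processes are only spawned lazily, so constructing the pools here is cheap):

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

    from adaptive.runner import SequentialExecutor, _get_ncores

    assert _get_ncores(ProcessPoolExecutor(max_workers=4)) == 4  # read from the private _max_workers
    assert _get_ncores(ThreadPoolExecutor(max_workers=8)) == 8
    assert _get_ncores(SequentialExecutor()) == 1                # always one, by definition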
+
+
+ # -- Runner definitions
+
+
class BaseRunner(metaclass=abc.ABCMeta):
-     r"""Base class for runners that use `concurrent.futures.Executors`.
+     r"""Base class for runners that use `concurrent.futures.Executor`\'s.

    Parameters
    ----------
@@ -133,12 +210,17 @@ def __init__(
        learner: BaseLearner,
        goal: Callable,
        *,
-         executor=None,
-         ntasks=None,
-         log=False,
-         shutdown_executor=False,
-         retries=0,
-         raise_if_retries_exceeded=True,
+         executor: Union[
+             ThirdPartyExecutor,
+             concurrent.ProcessPoolExecutor,
+             concurrent.ThreadPoolExecutor,
+             SequentialExecutor,
+         ] = None,
+         ntasks: int = None,
+         log: bool = False,
+         shutdown_executor: bool = False,
+         retries: int = 0,
+         raise_if_retries_exceeded: bool = True,
    ) -> None:

        self.executor = _ensure_executor(executor)
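
In practice the executor argument annotated above is passed straight through from the public runners. A hedged usage sketch built on adaptive's documented Learner1D / BlockingRunner API:

    from concurrent.futures import ProcessPoolExecutor

    import adaptive


    def f(x):
        return x ** 2  # module-level, so it can be pickled and shipped to worker processes


    if __name__ == "__main__":  # needed for process pools on spawn-based platforms
        learner = adaptive.Learner1D(f, bounds=(-1, 1))
        runner = adaptive.BlockingRunner(
            learner,
            goal=lambda lrn: lrn.npoints >= 100,
            executor=ProcessPoolExecutor(max_workers=4),  # any executor the signature above accepts
        )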
@@ -153,7 +235,7 @@ def __init__(
        self.shutdown_executor = shutdown_executor or (executor is None)

        self.learner = learner
-         self.log = [] if log else None
+         self.log: Optional[list] = [] if log else None

        # Timing
        self.start_time = time.time()
@@ -216,7 +298,12 @@ def overhead(self):
        return (1 - t_function / t_total) * 100

    def _process_futures(
-         self, done_futs: Union[Set[Future], Set[Future], Set[AsyncResult], Set[Task]]
+         self,
+         done_futs: Union[
+             Set[Future],
+             Set[AsyncResult],  # XXX: AsyncResult might not be imported
+             Set[Task],
+         ],
    ) -> None:
        for fut in done_futs:
            x = self.pending_points.pop(fut)
@@ -240,7 +327,11 @@ def _process_futures(

    def _get_futures(
        self,
-     ) -> Union[List[Task], List[Future], List[Future], List[AsyncResult]]:
+     ) -> Union[
+         List[Task],
+         List[Future],
+         List[AsyncResult],  # XXX: AsyncResult might not be imported
+     ]:
        # Launch tasks to replace the ones that completed
        # on the last iteration, making sure to fill workers
        # that have started since the last iteration.
@@ -363,8 +454,13 @@ def __init__(
        learner: BaseLearner,
        goal: Callable,
        *,
-         executor=None,
-         ntasks=None,
+         executor: Union[
+             ThirdPartyExecutor,
+             concurrent.ProcessPoolExecutor,
+             concurrent.ThreadPoolExecutor,
+             SequentialExecutor,
+         ] = None,
+         ntasks: Optional[int] = None,
        log=False,
        shutdown_executor=False,
        retries=0,
@@ -386,9 +482,7 @@ def __init__(
        )
        self._run()

-     def _submit(
-         self, x: Union[Tuple[int, int], int, Tuple[float, float], float]
-     ) -> Union[Future, AsyncResult]:
+     def _submit(self, x: Union[Tuple[float, ...], float, int]) -> Future:
        return self.executor.submit(self.learner.function, x)

    def _run(self) -> None:
@@ -494,13 +588,18 @@ def __init__(
        learner: BaseLearner,
        goal: Optional[Callable] = None,
        *,
-         executor=None,
-         ntasks=None,
-         log=False,
-         shutdown_executor=False,
+         executor: Union[
+             ThirdPartyExecutor,
+             concurrent.ProcessPoolExecutor,
+             concurrent.ThreadPoolExecutor,
+             SequentialExecutor,
+         ] = None,
+         ntasks: Optional[int] = None,
+         log: bool = False,
+         shutdown_executor: bool = False,
        ioloop=None,
-         retries=0,
-         raise_if_retries_exceeded=True,
+         retries: int = 0,
+         raise_if_retries_exceeded: bool = True,
    ) -> None:

        if goal is None:
@@ -640,7 +739,7 @@ async def _run(self) -> None:
                await asyncio.wait(remaining)
            self._cleanup()

-     def elapsed_time(self):
+     def elapsed_time(self) -> float:
        """Return the total time elapsed since the runner
        was started."""
        if self.task.done():
@@ -653,7 +752,7 @@ def elapsed_time(self):
            end_time = time.time()
        return end_time - self.start_time

-     def start_periodic_saving(self, save_kwargs, interval):
+     def start_periodic_saving(self, save_kwargs: Dict[str, Any], interval: int):
        """Periodically save the learner's data.

        Parameters
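
Typical use of start_periodic_saving, sketched from the call pattern in adaptive's documentation (the exact keywords in save_kwargs depend on the learner's save method):

    import adaptive


    def f(x):
        return x ** 2


    learner = adaptive.Learner1D(f, bounds=(-1, 1))
    runner = adaptive.Runner(learner, goal=lambda lrn: lrn.npoints >= 1000)

    # Save the learner's data every 10 minutes; the kwargs are forwarded to learner.save().
    runner.start_periodic_saving(save_kwargs={"fname": "data/learner.pickle"}, interval=600)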
@@ -711,16 +810,7 @@ def simple(learner: BaseLearner, goal: Callable) -> None:
            learner.tell(x, y)


- def replay_log(
-     learner: BaseLearner,
-     log: List[
-         Union[
-             Tuple[str, int],
-             Tuple[str, Tuple[int, int, int], float],
-             Tuple[str, Tuple[float, float, float], float],
-         ]
-     ],
- ) -> None:
+ def replay_log(learner: BaseLearner, log) -> None:
    """Apply a sequence of method calls to a learner.

    This is useful for debugging runners.
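
A sketch of the intended use: run with log=True, then re-apply the recorded method calls (tuples of the form ("method_name", *args), as far as this module's logging goes) to a fresh learner:

    import adaptive
    from adaptive.runner import SequentialExecutor, replay_log


    def f(x):
        return x ** 2


    learner = adaptive.Learner1D(f, bounds=(-1, 1))
    runner = adaptive.BlockingRunner(
        learner,
        goal=lambda lrn: lrn.npoints >= 20,
        executor=SequentialExecutor(),
        log=True,
    )

    # Rebuild an equivalent learner by replaying the recorded ask/tell calls.
    reconstructed = adaptive.Learner1D(f, bounds=(-1, 1))
    replay_log(reconstructed, runner.log)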
@@ -771,67 +861,3 @@ def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
    """
    stop_time = time.time() + seconds + 60 * minutes + 3600 * hours
    return lambda _: time.time() > stop_time
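
The goal returned by stop_after is just a closure over a deadline, which the snippet below illustrates directly from the lines shown above:

    import time

    from adaptive.runner import stop_after

    goal = stop_after(seconds=2)
    print(goal(None))  # False: the deadline is still ~2 seconds away (the learner argument is ignored)
    time.sleep(2.1)
    print(goal(None))  # True: time.time() is now past stop_time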
-
-
- # -- Internal executor-related, things
-
-
- class SequentialExecutor(concurrent.Executor):
-     """A trivial executor that runs functions synchronously.
-
-     This executor is mainly for testing.
-     """
-
-     def submit(self, fn: Callable, *args, **kwargs) -> Future:
-         fut = concurrent.Future()
-         try:
-             fut.set_result(fn(*args, **kwargs))
-         except Exception as e:
-             fut.set_exception(e)
-         return fut
-
-     def map(self, fn, *iterable, timeout=None, chunksize=1):
-         return map(fn, iterable)
-
-     def shutdown(self, wait=True):
-         pass
-
-
- def _ensure_executor(
-     executor: Optional[Union[Client, Client, ProcessPoolExecutor, SequentialExecutor]]
- ) -> Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]:
-     if executor is None:
-         executor = _default_executor(**_default_executor_kwargs)
-
-     if isinstance(executor, concurrent.Executor):
-         return executor
-     elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
-         return executor.executor()
-     elif with_distributed and isinstance(executor, distributed.Client):
-         return executor.get_executor()
-     else:
-         raise TypeError(
-             "Only a concurrent.futures.Executor, distributed.Client,"
-             " or ipyparallel.Client can be used."
-         )
-
-
- def _get_ncores(
-     ex: Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]
- ) -> int:
-     """Return the maximum number of cores that an executor can use."""
-     if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
-         return len(ex.view)
-     elif isinstance(
-         ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
-     ):
-         return ex._max_workers  # not public API!
-     elif isinstance(ex, SequentialExecutor):
-         return 1
-     elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
-         return sum(n for n in ex._client.ncores().values())
-     elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
-         ex.bootup()  # wait until all workers are up and running
-         return ex._pool.size  # not public API!
-     else:
-         raise TypeError(f"Cannot get number of cores for {ex.__class__}")