8
8
import time
9
9
import traceback
10
10
import warnings
11
+ from _asyncio import Future , Task
12
+ from concurrent .futures .process import ProcessPoolExecutor
11
13
from contextlib import suppress
14
+ from typing import Any , Callable , List , Optional , Set , Tuple , Union
12
15
16
+ from distributed .cfexecutor import ClientExecutor
17
+ from distributed .client import Client
18
+ from ipyparallel .client .asyncresult import AsyncResult
19
+ from ipyparallel .client .view import ViewExecutor
20
+ from numpy import float64
21
+
22
+ from adaptive .learner .learner1D import Learner1D
23
+ from adaptive .learner .learner2D import Learner2D
24
+ from adaptive .learner .learnerND import LearnerND
13
25
from adaptive .notebook_integration import in_ipynb , live_info , live_plot
14
26
15
27
try :
@@ -121,16 +133,16 @@ class BaseRunner(metaclass=abc.ABCMeta):
121
133
122
134
def __init__ (
123
135
self ,
124
- learner ,
125
- goal ,
136
+ learner : Union [ Learner1D , Learner2D , LearnerND ] ,
137
+ goal : Callable ,
126
138
* ,
127
139
executor = None ,
128
140
ntasks = None ,
129
141
log = False ,
130
142
shutdown_executor = False ,
131
143
retries = 0 ,
132
144
raise_if_retries_exceeded = True ,
133
- ):
145
+ ) -> None :
134
146
135
147
self .executor = _ensure_executor (executor )
136
148
self .goal = goal
@@ -157,7 +169,7 @@ def __init__(
157
169
self .to_retry = {}
158
170
self .tracebacks = {}
159
171
160
- def _get_max_tasks (self ):
172
+ def _get_max_tasks (self ) -> int :
161
173
return self ._max_tasks or _get_ncores (self .executor )
162
174
163
175
def _do_raise (self , e , x ):
@@ -169,10 +181,10 @@ def _do_raise(self, e, x):
169
181
) from e
170
182
171
183
@property
172
- def do_log (self ):
184
+ def do_log (self ) -> bool :
173
185
return self .log is not None
174
186
175
- def _ask (self , n ) :
187
+ def _ask (self , n : int ) -> Any :
176
188
points = [
177
189
p for p in self .to_retry .keys () if p not in self .pending_points .values ()
178
190
][:n ]
@@ -206,7 +218,9 @@ def overhead(self):
206
218
t_total = self .elapsed_time ()
207
219
return (1 - t_function / t_total ) * 100
208
220
209
- def _process_futures (self , done_futs ):
221
+ def _process_futures (
222
+ self , done_futs : Union [Set [Future ], Set [Future ], Set [AsyncResult ], Set [Task ]]
223
+ ) -> None :
210
224
for fut in done_futs :
211
225
x = self .pending_points .pop (fut )
212
226
try :
@@ -227,7 +241,9 @@ def _process_futures(self, done_futs):
227
241
self .log .append (("tell" , x , y ))
228
242
self .learner .tell (x , y )
229
243
230
- def _get_futures (self ):
244
+ def _get_futures (
245
+ self ,
246
+ ) -> Union [List [Task ], List [Future ], List [Future ], List [AsyncResult ]]:
231
247
# Launch tasks to replace the ones that completed
232
248
# on the last iteration, making sure to fill workers
233
249
# that have started since the last iteration.
@@ -248,7 +264,7 @@ def _get_futures(self):
248
264
futures = list (self .pending_points .keys ())
249
265
return futures
250
266
251
- def _remove_unfinished (self ):
267
+ def _remove_unfinished (self ) -> List [ Future ] :
252
268
# remove points with 'None' values from the learner
253
269
self .learner .remove_unfinished ()
254
270
# cancel any outstanding tasks
@@ -257,7 +273,7 @@ def _remove_unfinished(self):
257
273
fut .cancel ()
258
274
return remaining
259
275
260
- def _cleanup (self ):
276
+ def _cleanup (self ) -> None :
261
277
if self .shutdown_executor :
262
278
# XXX: temporary set wait=True for Python 3.7
263
279
# see https://github.com/python-adaptive/adaptive/issues/156
@@ -347,16 +363,16 @@ class BlockingRunner(BaseRunner):
347
363
348
364
def __init__ (
349
365
self ,
350
- learner ,
351
- goal ,
366
+ learner : Union [ LearnerND , Learner2D , Learner1D ] ,
367
+ goal : Callable ,
352
368
* ,
353
369
executor = None ,
354
370
ntasks = None ,
355
371
log = False ,
356
372
shutdown_executor = False ,
357
373
retries = 0 ,
358
374
raise_if_retries_exceeded = True ,
359
- ):
375
+ ) -> None :
360
376
if inspect .iscoroutinefunction (learner .function ):
361
377
raise ValueError (
362
378
"Coroutine functions can only be used " "with 'AsyncRunner'."
@@ -373,10 +389,12 @@ def __init__(
373
389
)
374
390
self ._run ()
375
391
376
- def _submit (self , x ):
392
+ def _submit (
393
+ self , x : Union [Tuple [int , int ], int , Tuple [float64 , float64 ], float ]
394
+ ) -> Union [Future , AsyncResult ]:
377
395
return self .executor .submit (self .learner .function , x )
378
396
379
- def _run (self ):
397
+ def _run (self ) -> None :
380
398
first_completed = concurrent .FIRST_COMPLETED
381
399
382
400
if self ._get_max_tasks () < 1 :
@@ -476,8 +494,8 @@ class AsyncRunner(BaseRunner):
476
494
477
495
def __init__ (
478
496
self ,
479
- learner ,
480
- goal = None ,
497
+ learner : Union [ Learner1D , Learner2D ] ,
498
+ goal : Optional [ Callable ] = None ,
481
499
* ,
482
500
executor = None ,
483
501
ntasks = None ,
@@ -486,7 +504,7 @@ def __init__(
486
504
ioloop = None ,
487
505
retries = 0 ,
488
506
raise_if_retries_exceeded = True ,
489
- ):
507
+ ) -> None :
490
508
491
509
if goal is None :
492
510
@@ -539,7 +557,9 @@ def goal(_):
539
557
"'adaptive.notebook_extension()'"
540
558
)
541
559
542
- def _submit (self , x ):
560
+ def _submit (
561
+ self , x : Union [Tuple [int , int ], int , Tuple [float64 , float64 ], float ]
562
+ ) -> Union [Task , Future ]:
543
563
ioloop = self .ioloop
544
564
if inspect .iscoroutinefunction (self .learner .function ):
545
565
return ioloop .create_task (self .learner .function (x ))
@@ -604,7 +624,7 @@ def live_info(self, *, update_interval=0.1):
604
624
"""
605
625
return live_info (self , update_interval = update_interval )
606
626
607
- async def _run (self ):
627
+ async def _run (self ) -> None :
608
628
first_completed = asyncio .FIRST_COMPLETED
609
629
610
630
if self ._get_max_tasks () < 1 :
@@ -668,7 +688,7 @@ async def _saver(save_kwargs=save_kwargs, interval=interval):
668
688
Runner = AsyncRunner
669
689
670
690
671
- def simple (learner , goal ) :
691
+ def simple (learner : Any , goal : Callable ) -> None :
672
692
"""Run the learner until the goal is reached.
673
693
674
694
Requests a single point from the learner, evaluates
@@ -694,7 +714,16 @@ def simple(learner, goal):
694
714
learner .tell (x , y )
695
715
696
716
697
- def replay_log (learner , log ):
717
+ def replay_log (
718
+ learner : LearnerND ,
719
+ log : List [
720
+ Union [
721
+ Tuple [str , int ],
722
+ Tuple [str , Tuple [int , int , int ], float ],
723
+ Tuple [str , Tuple [float , float , float ], float ],
724
+ ]
725
+ ],
726
+ ) -> None :
698
727
"""Apply a sequence of method calls to a learner.
699
728
700
729
This is useful for debugging runners.
@@ -713,7 +742,7 @@ def replay_log(learner, log):
713
742
# --- Useful runner goals
714
743
715
744
716
- def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ):
745
+ def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ) -> Callable :
717
746
"""Stop a runner after a specified time.
718
747
719
748
For example, to specify a runner that should stop after
@@ -756,7 +785,7 @@ class SequentialExecutor(concurrent.Executor):
756
785
This executor is mainly for testing.
757
786
"""
758
787
759
- def submit (self , fn , * args , ** kwargs ):
788
+ def submit (self , fn : Callable , * args , ** kwargs ) -> Future :
760
789
fut = concurrent .Future ()
761
790
try :
762
791
fut .set_result (fn (* args , ** kwargs ))
@@ -771,7 +800,9 @@ def shutdown(self, wait=True):
771
800
pass
772
801
773
802
774
- def _ensure_executor (executor ):
803
+ def _ensure_executor (
804
+ executor : Optional [Union [Client , Client , ProcessPoolExecutor , SequentialExecutor ]]
805
+ ) -> Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]:
775
806
if executor is None :
776
807
executor = _default_executor (** _default_executor_kwargs )
777
808
@@ -788,7 +819,9 @@ def _ensure_executor(executor):
788
819
)
789
820
790
821
791
- def _get_ncores (ex ):
822
+ def _get_ncores (
823
+ ex : Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]
824
+ ) -> int :
792
825
"""Return the maximum number of cores that an executor can use."""
793
826
if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
794
827
return len (ex .view )
0 commit comments