Skip to content

Commit 8e10b5d

Browse files
committed
type hint fixes for adaptive/runner.py
1 parent 2f7c9d6 commit 8e10b5d

File tree

1 file changed

+129
-103
lines changed

1 file changed

+129
-103
lines changed

adaptive/runner.py

Lines changed: 129 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,39 @@
99
import traceback
1010
import warnings
1111
from _asyncio import Future, Task
12-
from concurrent.futures.process import ProcessPoolExecutor
1312
from contextlib import suppress
14-
from typing import Any, Callable, List, Optional, Set, Tuple, Union
15-
16-
from distributed.cfexecutor import ClientExecutor
17-
from distributed.client import Client
18-
from ipyparallel.client.asyncresult import AsyncResult
19-
from ipyparallel.client.view import ViewExecutor
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
2014

2115
from adaptive.learner import BaseLearner
2216
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
2317

18+
_ThirdPartyClient = []
19+
_ThirdPartyExecutor = []
20+
2421
try:
2522
import ipyparallel
23+
from ipyparallel.client.asyncresult import AsyncResult
2624

2725
with_ipyparallel = True
26+
_ThirdPartyClient.append(ipyparallel.Client)
27+
_ThirdPartyExecutor.append(ipyparallel.client.view.ViewExecutor)
2828
except ModuleNotFoundError:
2929
with_ipyparallel = False
3030

3131
try:
3232
import distributed
3333

3434
with_distributed = True
35+
_ThirdPartyClient.append(distributed.client.Client)
36+
_ThirdPartyExecutor.append(distributed.cfexecutor.ClientExecutor)
3537
except ModuleNotFoundError:
3638
with_distributed = False
3739

3840
try:
3941
import mpi4py.futures
4042

4143
with_mpi4py = True
44+
_ThirdPartyExecutor.append(mpi4py.futures.MPIPoolExecutor)
4245
except ModuleNotFoundError:
4346
with_mpi4py = False
4447

@@ -47,6 +50,8 @@
4750

4851
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
4952

53+
ThirdPartyClient = Union[tuple(_ThirdPartyClient)]
54+
ThirdPartyExecutor = Union[tuple(_ThirdPartyExecutor)]
5055

5156
if os.name == "nt":
5257
if with_distributed:
@@ -72,8 +77,80 @@ def _default_executor(*args, **kwargs):
7277
_default_executor_kwargs = {}
7378

7479

80+
# -- Internal executor-related, things
81+
82+
83+
class SequentialExecutor(concurrent.Executor):
84+
"""A trivial executor that runs functions synchronously.
85+
86+
This executor is mainly for testing.
87+
"""
88+
89+
def submit(self, fn: Callable, *args, **kwargs) -> Future:
90+
fut = concurrent.Future()
91+
try:
92+
fut.set_result(fn(*args, **kwargs))
93+
except Exception as e:
94+
fut.set_exception(e)
95+
return fut
96+
97+
def map(self, fn, *iterable, timeout=None, chunksize=1):
98+
return map(fn, iterable)
99+
100+
def shutdown(self, wait=True):
101+
pass
102+
103+
104+
def _ensure_executor(
105+
executor: Optional[Union[ThirdPartyClient, concurrent.Executor]]
106+
) -> concurrent.Executor:
107+
if executor is None:
108+
executor = _default_executor(**_default_executor_kwargs)
109+
110+
if isinstance(executor, concurrent.Executor):
111+
return executor
112+
elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
113+
return executor.executor()
114+
elif with_distributed and isinstance(executor, distributed.Client):
115+
return executor.get_executor()
116+
else:
117+
raise TypeError(
118+
"Only a concurrent.futures.Executor, distributed.Client,"
119+
" or ipyparallel.Client can be used."
120+
)
121+
122+
123+
def _get_ncores(
124+
ex: Union[
125+
ThirdPartyExecutor,
126+
concurrent.ProcessPoolExecutor,
127+
concurrent.ThreadPoolExecutor,
128+
SequentialExecutor,
129+
]
130+
) -> int:
131+
"""Return the maximum number of cores that an executor can use."""
132+
if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
133+
return len(ex.view)
134+
elif isinstance(
135+
ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
136+
):
137+
return ex._max_workers # not public API!
138+
elif isinstance(ex, SequentialExecutor):
139+
return 1
140+
elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
141+
return sum(n for n in ex._client.ncores().values())
142+
elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
143+
ex.bootup() # wait until all workers are up and running
144+
return ex._pool.size # not public API!
145+
else:
146+
raise TypeError(f"Cannot get number of cores for {ex.__class__}")
147+
148+
149+
# -- Runner definitions
150+
151+
75152
class BaseRunner(metaclass=abc.ABCMeta):
76-
r"""Base class for runners that use `concurrent.futures.Executors`.
153+
r"""Base class for runners that use `concurrent.futures.Executor`\'s.
77154
78155
Parameters
79156
----------
@@ -133,12 +210,17 @@ def __init__(
133210
learner: BaseLearner,
134211
goal: Callable,
135212
*,
136-
executor=None,
137-
ntasks=None,
138-
log=False,
139-
shutdown_executor=False,
140-
retries=0,
141-
raise_if_retries_exceeded=True,
213+
executor: Union[
214+
ThirdPartyExecutor,
215+
concurrent.ProcessPoolExecutor,
216+
concurrent.ThreadPoolExecutor,
217+
SequentialExecutor,
218+
] = None,
219+
ntasks: int = None,
220+
log: bool = False,
221+
shutdown_executor: bool = False,
222+
retries: int = 0,
223+
raise_if_retries_exceeded: bool = True,
142224
) -> None:
143225

144226
self.executor = _ensure_executor(executor)
@@ -216,7 +298,12 @@ def overhead(self):
216298
return (1 - t_function / t_total) * 100
217299

218300
def _process_futures(
219-
self, done_futs: Union[Set[Future], Set[Future], Set[AsyncResult], Set[Task]]
301+
self,
302+
done_futs: Union[
303+
Set[Future],
304+
Set[AsyncResult], # XXX: AsyncResult might not be imported
305+
Set[Task],
306+
],
220307
) -> None:
221308
for fut in done_futs:
222309
x = self.pending_points.pop(fut)
@@ -240,7 +327,11 @@ def _process_futures(
240327

241328
def _get_futures(
242329
self,
243-
) -> Union[List[Task], List[Future], List[Future], List[AsyncResult]]:
330+
) -> Union[
331+
List[Task],
332+
List[Future],
333+
List[AsyncResult], # XXX: AsyncResult might not be imported
334+
]:
244335
# Launch tasks to replace the ones that completed
245336
# on the last iteration, making sure to fill workers
246337
# that have started since the last iteration.
@@ -363,8 +454,13 @@ def __init__(
363454
learner: BaseLearner,
364455
goal: Callable,
365456
*,
366-
executor=None,
367-
ntasks=None,
457+
executor: Union[
458+
ThirdPartyExecutor,
459+
concurrent.ProcessPoolExecutor,
460+
concurrent.ThreadPoolExecutor,
461+
SequentialExecutor,
462+
] = None,
463+
ntasks: Optional[int] = None,
368464
log=False,
369465
shutdown_executor=False,
370466
retries=0,
@@ -386,9 +482,7 @@ def __init__(
386482
)
387483
self._run()
388484

389-
def _submit(
390-
self, x: Union[Tuple[int, int], int, Tuple[float, float], float]
391-
) -> Union[Future, AsyncResult]:
485+
def _submit(self, x: Union[Tuple[float, ...], float, int]) -> Future:
392486
return self.executor.submit(self.learner.function, x)
393487

394488
def _run(self) -> None:
@@ -494,13 +588,18 @@ def __init__(
494588
learner: BaseLearner,
495589
goal: Optional[Callable] = None,
496590
*,
497-
executor=None,
498-
ntasks=None,
499-
log=False,
500-
shutdown_executor=False,
591+
executor: Union[
592+
ThirdPartyExecutor,
593+
concurrent.ProcessPoolExecutor,
594+
concurrent.ThreadPoolExecutor,
595+
SequentialExecutor,
596+
] = None,
597+
ntasks: Optional[int] = None,
598+
log: bool = False,
599+
shutdown_executor: bool = False,
501600
ioloop=None,
502-
retries=0,
503-
raise_if_retries_exceeded=True,
601+
retries: int = 0,
602+
raise_if_retries_exceeded: bool = True,
504603
) -> None:
505604

506605
if goal is None:
@@ -640,7 +739,7 @@ async def _run(self) -> None:
640739
await asyncio.wait(remaining)
641740
self._cleanup()
642741

643-
def elapsed_time(self):
742+
def elapsed_time(self) -> float:
644743
"""Return the total time elapsed since the runner
645744
was started."""
646745
if self.task.done():
@@ -653,7 +752,7 @@ def elapsed_time(self):
653752
end_time = time.time()
654753
return end_time - self.start_time
655754

656-
def start_periodic_saving(self, save_kwargs, interval):
755+
def start_periodic_saving(self, save_kwargs: Dict[str, Any], interval: int):
657756
"""Periodically save the learner's data.
658757
659758
Parameters
@@ -711,16 +810,7 @@ def simple(learner: BaseLearner, goal: Callable) -> None:
711810
learner.tell(x, y)
712811

713812

714-
def replay_log(
715-
learner: BaseLearner,
716-
log: List[
717-
Union[
718-
Tuple[str, int],
719-
Tuple[str, Tuple[int, int, int], float],
720-
Tuple[str, Tuple[float, float, float], float],
721-
]
722-
],
723-
) -> None:
813+
def replay_log(learner: BaseLearner, log) -> None:
724814
"""Apply a sequence of method calls to a learner.
725815
726816
This is useful for debugging runners.
@@ -771,67 +861,3 @@ def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
771861
"""
772862
stop_time = time.time() + seconds + 60 * minutes + 3600 * hours
773863
return lambda _: time.time() > stop_time
774-
775-
776-
# -- Internal executor-related, things
777-
778-
779-
class SequentialExecutor(concurrent.Executor):
780-
"""A trivial executor that runs functions synchronously.
781-
782-
This executor is mainly for testing.
783-
"""
784-
785-
def submit(self, fn: Callable, *args, **kwargs) -> Future:
786-
fut = concurrent.Future()
787-
try:
788-
fut.set_result(fn(*args, **kwargs))
789-
except Exception as e:
790-
fut.set_exception(e)
791-
return fut
792-
793-
def map(self, fn, *iterable, timeout=None, chunksize=1):
794-
return map(fn, iterable)
795-
796-
def shutdown(self, wait=True):
797-
pass
798-
799-
800-
def _ensure_executor(
801-
executor: Optional[Union[Client, Client, ProcessPoolExecutor, SequentialExecutor]]
802-
) -> Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]:
803-
if executor is None:
804-
executor = _default_executor(**_default_executor_kwargs)
805-
806-
if isinstance(executor, concurrent.Executor):
807-
return executor
808-
elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
809-
return executor.executor()
810-
elif with_distributed and isinstance(executor, distributed.Client):
811-
return executor.get_executor()
812-
else:
813-
raise TypeError(
814-
"Only a concurrent.futures.Executor, distributed.Client,"
815-
" or ipyparallel.Client can be used."
816-
)
817-
818-
819-
def _get_ncores(
820-
ex: Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]
821-
) -> int:
822-
"""Return the maximum number of cores that an executor can use."""
823-
if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
824-
return len(ex.view)
825-
elif isinstance(
826-
ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
827-
):
828-
return ex._max_workers # not public API!
829-
elif isinstance(ex, SequentialExecutor):
830-
return 1
831-
elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
832-
return sum(n for n in ex._client.ncores().values())
833-
elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
834-
ex.bootup() # wait until all workers are up and running
835-
return ex._pool.size # not public API!
836-
else:
837-
raise TypeError(f"Cannot get number of cores for {ex.__class__}")

0 commit comments

Comments
 (0)