Skip to content

Commit fd61d3c

Browse files
committed
add type annotations for adaptive/runner.py
1 parent e35e0db commit fd61d3c

File tree

1 file changed

+59
-26
lines changed

1 file changed

+59
-26
lines changed

adaptive/runner.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,20 @@
88
import time
99
import traceback
1010
import warnings
11+
from _asyncio import Future, Task
12+
from concurrent.futures.process import ProcessPoolExecutor
1113
from contextlib import suppress
14+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
1215

16+
from distributed.cfexecutor import ClientExecutor
17+
from distributed.client import Client
18+
from ipyparallel.client.asyncresult import AsyncResult
19+
from ipyparallel.client.view import ViewExecutor
20+
from numpy import float64
21+
22+
from adaptive.learner.learner1D import Learner1D
23+
from adaptive.learner.learner2D import Learner2D
24+
from adaptive.learner.learnerND import LearnerND
1325
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
1426

1527
try:
@@ -121,16 +133,16 @@ class BaseRunner(metaclass=abc.ABCMeta):
121133

122134
def __init__(
123135
self,
124-
learner,
125-
goal,
136+
learner: Union[Learner1D, Learner2D, LearnerND],
137+
goal: Callable,
126138
*,
127139
executor=None,
128140
ntasks=None,
129141
log=False,
130142
shutdown_executor=False,
131143
retries=0,
132144
raise_if_retries_exceeded=True,
133-
):
145+
) -> None:
134146

135147
self.executor = _ensure_executor(executor)
136148
self.goal = goal
@@ -157,7 +169,7 @@ def __init__(
157169
self.to_retry = {}
158170
self.tracebacks = {}
159171

160-
def _get_max_tasks(self):
172+
def _get_max_tasks(self) -> int:
161173
return self._max_tasks or _get_ncores(self.executor)
162174

163175
def _do_raise(self, e, x):
@@ -169,10 +181,10 @@ def _do_raise(self, e, x):
169181
) from e
170182

171183
@property
172-
def do_log(self):
184+
def do_log(self) -> bool:
173185
return self.log is not None
174186

175-
def _ask(self, n):
187+
def _ask(self, n: int) -> Any:
176188
points = [
177189
p for p in self.to_retry.keys() if p not in self.pending_points.values()
178190
][:n]
@@ -206,7 +218,9 @@ def overhead(self):
206218
t_total = self.elapsed_time()
207219
return (1 - t_function / t_total) * 100
208220

209-
def _process_futures(self, done_futs):
221+
def _process_futures(
222+
self, done_futs: Union[Set[Future], Set[Future], Set[AsyncResult], Set[Task]]
223+
) -> None:
210224
for fut in done_futs:
211225
x = self.pending_points.pop(fut)
212226
try:
@@ -227,7 +241,9 @@ def _process_futures(self, done_futs):
227241
self.log.append(("tell", x, y))
228242
self.learner.tell(x, y)
229243

230-
def _get_futures(self):
244+
def _get_futures(
245+
self,
246+
) -> Union[List[Task], List[Future], List[Future], List[AsyncResult]]:
231247
# Launch tasks to replace the ones that completed
232248
# on the last iteration, making sure to fill workers
233249
# that have started since the last iteration.
@@ -248,7 +264,7 @@ def _get_futures(self):
248264
futures = list(self.pending_points.keys())
249265
return futures
250266

251-
def _remove_unfinished(self):
267+
def _remove_unfinished(self) -> List[Future]:
252268
# remove points with 'None' values from the learner
253269
self.learner.remove_unfinished()
254270
# cancel any outstanding tasks
@@ -257,7 +273,7 @@ def _remove_unfinished(self):
257273
fut.cancel()
258274
return remaining
259275

260-
def _cleanup(self):
276+
def _cleanup(self) -> None:
261277
if self.shutdown_executor:
262278
# XXX: temporary set wait=True for Python 3.7
263279
# see https://github.com/python-adaptive/adaptive/issues/156
@@ -347,16 +363,16 @@ class BlockingRunner(BaseRunner):
347363

348364
def __init__(
349365
self,
350-
learner,
351-
goal,
366+
learner: Union[LearnerND, Learner2D, Learner1D],
367+
goal: Callable,
352368
*,
353369
executor=None,
354370
ntasks=None,
355371
log=False,
356372
shutdown_executor=False,
357373
retries=0,
358374
raise_if_retries_exceeded=True,
359-
):
375+
) -> None:
360376
if inspect.iscoroutinefunction(learner.function):
361377
raise ValueError(
362378
"Coroutine functions can only be used " "with 'AsyncRunner'."
@@ -373,10 +389,12 @@ def __init__(
373389
)
374390
self._run()
375391

376-
def _submit(self, x):
392+
def _submit(
393+
self, x: Union[Tuple[int, int], int, Tuple[float64, float64], float]
394+
) -> Union[Future, AsyncResult]:
377395
return self.executor.submit(self.learner.function, x)
378396

379-
def _run(self):
397+
def _run(self) -> None:
380398
first_completed = concurrent.FIRST_COMPLETED
381399

382400
if self._get_max_tasks() < 1:
@@ -476,8 +494,8 @@ class AsyncRunner(BaseRunner):
476494

477495
def __init__(
478496
self,
479-
learner,
480-
goal=None,
497+
learner: Union[Learner1D, Learner2D],
498+
goal: Optional[Callable] = None,
481499
*,
482500
executor=None,
483501
ntasks=None,
@@ -486,7 +504,7 @@ def __init__(
486504
ioloop=None,
487505
retries=0,
488506
raise_if_retries_exceeded=True,
489-
):
507+
) -> None:
490508

491509
if goal is None:
492510

@@ -539,7 +557,9 @@ def goal(_):
539557
"'adaptive.notebook_extension()'"
540558
)
541559

542-
def _submit(self, x):
560+
def _submit(
561+
self, x: Union[Tuple[int, int], int, Tuple[float64, float64], float]
562+
) -> Union[Task, Future]:
543563
ioloop = self.ioloop
544564
if inspect.iscoroutinefunction(self.learner.function):
545565
return ioloop.create_task(self.learner.function(x))
@@ -604,7 +624,7 @@ def live_info(self, *, update_interval=0.1):
604624
"""
605625
return live_info(self, update_interval=update_interval)
606626

607-
async def _run(self):
627+
async def _run(self) -> None:
608628
first_completed = asyncio.FIRST_COMPLETED
609629

610630
if self._get_max_tasks() < 1:
@@ -668,7 +688,7 @@ async def _saver(save_kwargs=save_kwargs, interval=interval):
668688
Runner = AsyncRunner
669689

670690

671-
def simple(learner, goal):
691+
def simple(learner: Any, goal: Callable) -> None:
672692
"""Run the learner until the goal is reached.
673693
674694
Requests a single point from the learner, evaluates
@@ -694,7 +714,16 @@ def simple(learner, goal):
694714
learner.tell(x, y)
695715

696716

697-
def replay_log(learner, log):
717+
def replay_log(
718+
learner: LearnerND,
719+
log: List[
720+
Union[
721+
Tuple[str, int],
722+
Tuple[str, Tuple[int, int, int], float],
723+
Tuple[str, Tuple[float, float, float], float],
724+
]
725+
],
726+
) -> None:
698727
"""Apply a sequence of method calls to a learner.
699728
700729
This is useful for debugging runners.
@@ -713,7 +742,7 @@ def replay_log(learner, log):
713742
# --- Useful runner goals
714743

715744

716-
def stop_after(*, seconds=0, minutes=0, hours=0):
745+
def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
717746
"""Stop a runner after a specified time.
718747
719748
For example, to specify a runner that should stop after
@@ -756,7 +785,7 @@ class SequentialExecutor(concurrent.Executor):
756785
This executor is mainly for testing.
757786
"""
758787

759-
def submit(self, fn, *args, **kwargs):
788+
def submit(self, fn: Callable, *args, **kwargs) -> Future:
760789
fut = concurrent.Future()
761790
try:
762791
fut.set_result(fn(*args, **kwargs))
@@ -771,7 +800,9 @@ def shutdown(self, wait=True):
771800
pass
772801

773802

774-
def _ensure_executor(executor):
803+
def _ensure_executor(
804+
executor: Optional[Union[Client, Client, ProcessPoolExecutor, SequentialExecutor]]
805+
) -> Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]:
775806
if executor is None:
776807
executor = _default_executor(**_default_executor_kwargs)
777808

@@ -788,7 +819,9 @@ def _ensure_executor(executor):
788819
)
789820

790821

791-
def _get_ncores(ex):
822+
def _get_ncores(
823+
ex: Union[SequentialExecutor, ProcessPoolExecutor, ViewExecutor, ClientExecutor]
824+
) -> int:
792825
"""Return the maximum number of cores that an executor can use."""
793826
if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
794827
return len(ex.view)

0 commit comments

Comments
 (0)