Skip to content

Commit 0f27eb8

Browse files
committed
use BaseLearner instead of Union of some learners
1 parent 18db3a3 commit 0f27eb8

File tree

3 files changed

+12
-22
lines changed

3 files changed

+12
-22
lines changed

adaptive/learner/data_saver.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
from operator import itemgetter
44
from typing import Callable, Dict, Tuple, Union
55

6-
from adaptive.learner.average_learner import AverageLearner
76
from adaptive.learner.base_learner import BaseLearner
8-
from adaptive.learner.learner1D import Learner1D
9-
from adaptive.learner.learner2D import Learner2D
10-
from adaptive.learner.learnerND import LearnerND
117
from adaptive.utils import copy_docstring_from
128

139

@@ -31,11 +27,7 @@ class DataSaver:
3127
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
3228
"""
3329

34-
def __init__(
35-
self,
36-
learner: Union[Learner2D, Learner1D, LearnerND, AverageLearner],
37-
arg_picker: itemgetter,
38-
) -> None:
30+
def __init__(self, learner: BaseLearner, arg_picker: itemgetter,) -> None:
3931
self.learner = learner
4032
self.extra_data = OrderedDict()
4133
self.function = learner.function

adaptive/runner.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
from ipyparallel.client.view import ViewExecutor
2020
from numpy import float64
2121

22-
from adaptive.learner.learner1D import Learner1D
23-
from adaptive.learner.learner2D import Learner2D
24-
from adaptive.learner.learnerND import LearnerND
22+
from adaptive.learner import BaseLearner
2523
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
2624

2725
try:
@@ -133,7 +131,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
133131

134132
def __init__(
135133
self,
136-
learner: Union[Learner1D, Learner2D, LearnerND],
134+
learner: BaseLearner,
137135
goal: Callable,
138136
*,
139137
executor=None,
@@ -363,7 +361,7 @@ class BlockingRunner(BaseRunner):
363361

364362
def __init__(
365363
self,
366-
learner: Union[LearnerND, Learner2D, Learner1D],
364+
learner: BaseLearner,
367365
goal: Callable,
368366
*,
369367
executor=None,
@@ -494,7 +492,7 @@ class AsyncRunner(BaseRunner):
494492

495493
def __init__(
496494
self,
497-
learner: Union[Learner1D, Learner2D],
495+
learner: BaseLearner,
498496
goal: Optional[Callable] = None,
499497
*,
500498
executor=None,
@@ -688,7 +686,7 @@ async def _saver(save_kwargs=save_kwargs, interval=interval):
688686
Runner = AsyncRunner
689687

690688

691-
def simple(learner: Any, goal: Callable) -> None:
689+
def simple(learner: BaseLearner, goal: Callable) -> None:
692690
"""Run the learner until the goal is reached.
693691
694692
Requests a single point from the learner, evaluates
@@ -715,7 +713,7 @@ def simple(learner: Any, goal: Callable) -> None:
715713

716714

717715
def replay_log(
718-
learner: LearnerND,
716+
learner: BaseLearner,
719717
log: List[
720718
Union[
721719
Tuple[str, int],

adaptive/tests/test_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
import time
3-
from typing import Callable, Iterator, Union
3+
from typing import Callable, Iterator
44

55
import pytest
66
from distributed.client import Client
77

8-
from adaptive.learner import Learner1D, Learner2D
8+
from adaptive.learner import BaseLearner, Learner1D, Learner2D
99
from adaptive.runner import (
1010
AsyncRunner,
1111
BlockingRunner,
@@ -17,19 +17,19 @@
1717
)
1818

1919

20-
def blocking_runner(learner: Union[Learner1D, Learner2D], goal: Callable) -> None:
20+
def blocking_runner(learner: BaseLearner, goal: Callable) -> None:
2121
BlockingRunner(learner, goal, executor=SequentialExecutor())
2222

2323

24-
def async_runner(learner: Union[Learner1D, Learner2D], goal: Callable) -> None:
24+
def async_runner(learner: BaseLearner, goal: Callable) -> None:
2525
runner = AsyncRunner(learner, goal, executor=SequentialExecutor())
2626
asyncio.get_event_loop().run_until_complete(runner.task)
2727

2828

2929
runners = [simple, blocking_runner, async_runner]
3030

3131

32-
def trivial_goal(learner: Union[Learner1D, Learner2D]) -> bool:
32+
def trivial_goal(learner: BaseLearner) -> bool:
3333
return learner.npoints > 10
3434

3535

0 commit comments

Comments
 (0)