|
9 | 9 | import numpy as np
|
10 | 10 | from numpy import float64, int64
|
11 | 11 |
|
12 |
| -from adaptive.learner.average_learner import AverageLearner |
13 | 12 | from adaptive.learner.base_learner import BaseLearner
|
14 |
| -from adaptive.learner.learner1D import Learner1D |
15 |
| -from adaptive.learner.learner2D import Learner2D |
16 |
| -from adaptive.learner.learnerND import LearnerND |
17 |
| -from adaptive.learner.sequence_learner import SequenceLearner, _IgnoreFirstArgument |
18 | 13 | from adaptive.notebook_integration import ensure_holoviews
|
19 | 14 | from adaptive.utils import cache_latest, named_product, restore
|
20 | 15 |
|
21 | 16 |
|
22 |
| -def dispatch( |
23 |
| - child_functions: Union[List[Callable], List[partial], List[_IgnoreFirstArgument]], |
24 |
| - arg: Any, |
25 |
| -) -> Union[int, float64, float]: |
| 17 | +def dispatch(child_functions: List[Callable], arg: Any,) -> Union[int, float64, float]: |
26 | 18 | index, x = arg
|
27 | 19 | return child_functions[index](x)
|
28 | 20 |
|
@@ -79,17 +71,7 @@ class BalancingLearner(BaseLearner):
|
79 | 71 | """
|
80 | 72 |
|
81 | 73 | def __init__(
|
82 |
| - self, |
83 |
| - learners: Union[ |
84 |
| - List[SequenceLearner], |
85 |
| - List[AverageLearner], |
86 |
| - List[Learner2D], |
87 |
| - List[Learner1D], |
88 |
| - List[LearnerND], |
89 |
| - ], |
90 |
| - *, |
91 |
| - cdims=None, |
92 |
| - strategy="loss_improvements" |
| 74 | + self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements" |
93 | 75 | ) -> None:
|
94 | 76 | self.learners = learners
|
95 | 77 |
|
@@ -246,7 +228,7 @@ def _ask_and_tell_based_on_cycle(
|
246 | 228 |
|
247 | 229 | return points, loss_improvements
|
248 | 230 |
|
249 |
| - def ask(self, n: int, tell_pending: bool = True) -> Any: |
| 231 | + def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]]: |
250 | 232 | """Chose points for learners."""
|
251 | 233 | if n == 0:
|
252 | 234 | return [], []
|
@@ -369,7 +351,9 @@ def remove_unfinished(self):
|
369 | 351 | learner.remove_unfinished()
|
370 | 352 |
|
371 | 353 | @classmethod
|
372 |
| - def from_product(cls, f, learner_type, learner_kwargs, combos): |
| 354 | + def from_product( |
| 355 | + cls, f, learner_type, learner_kwargs, combos |
| 356 | + ) -> "BalancingLearner": |
373 | 357 | """Create a `BalancingLearner` with learners of all combinations of
|
374 | 358 | named variables’ values. The `cdims` will be set correctly, so calling
|
375 | 359 | `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
|
|
0 commit comments