Skip to content

Commit 18db3a3

Browse files
committed
type hint fixes for adaptive/learner/balancing_learner.py
1 parent 2f03de1 commit 18db3a3

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,12 @@
99
import numpy as np
1010
from numpy import float64, int64
1111

12-
from adaptive.learner.average_learner import AverageLearner
1312
from adaptive.learner.base_learner import BaseLearner
14-
from adaptive.learner.learner1D import Learner1D
15-
from adaptive.learner.learner2D import Learner2D
16-
from adaptive.learner.learnerND import LearnerND
17-
from adaptive.learner.sequence_learner import SequenceLearner, _IgnoreFirstArgument
1813
from adaptive.notebook_integration import ensure_holoviews
1914
from adaptive.utils import cache_latest, named_product, restore
2015

2116

22-
def dispatch(
23-
child_functions: Union[List[Callable], List[partial], List[_IgnoreFirstArgument]],
24-
arg: Any,
25-
) -> Union[int, float64, float]:
17+
def dispatch(child_functions: List[Callable], arg: Any,) -> Union[int, float64, float]:
2618
index, x = arg
2719
return child_functions[index](x)
2820

@@ -79,17 +71,7 @@ class BalancingLearner(BaseLearner):
7971
"""
8072

8173
def __init__(
82-
self,
83-
learners: Union[
84-
List[SequenceLearner],
85-
List[AverageLearner],
86-
List[Learner2D],
87-
List[Learner1D],
88-
List[LearnerND],
89-
],
90-
*,
91-
cdims=None,
92-
strategy="loss_improvements"
74+
self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements"
9375
) -> None:
9476
self.learners = learners
9577

@@ -246,7 +228,7 @@ def _ask_and_tell_based_on_cycle(
246228

247229
return points, loss_improvements
248230

249-
def ask(self, n: int, tell_pending: bool = True) -> Any:
231+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]]:
250232
"""Chose points for learners."""
251233
if n == 0:
252234
return [], []
@@ -369,7 +351,9 @@ def remove_unfinished(self):
369351
learner.remove_unfinished()
370352

371353
@classmethod
372-
def from_product(cls, f, learner_type, learner_kwargs, combos):
354+
def from_product(
355+
cls, f, learner_type, learner_kwargs, combos
356+
) -> "BalancingLearner":
373357
"""Create a `BalancingLearner` with learners of all combinations of
374358
named variables’ values. The `cdims` will be set correctly, so calling
375359
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.

0 commit comments

Comments
 (0)