Skip to content

Commit 998c06c

Browse files
committed
Fix typeguard issues
1 parent 84a0af6 commit 998c06c

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

adaptive/learner/average_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from adaptive.learner.base_learner import BaseLearner
1010
from adaptive.notebook_integration import ensure_holoviews
11-
from adaptive.types import Float, Real
11+
from adaptive.types import Float, Int, Real
1212
from adaptive.utils import (
1313
assign_defaults,
1414
cache_latest,
@@ -127,7 +127,7 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]
127127
self.tell_pending(p)
128128
return points, loss_improvements
129129

130-
def tell(self, n: int, value: Real) -> None:
130+
def tell(self, n: Int, value: Real) -> None:
131131
if n in self.data:
132132
# The point has already been added before.
133133
return

adaptive/learner/average_learner1D.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from adaptive.learner.learner1D import Learner1D, _get_intervals
1616
from adaptive.notebook_integration import ensure_holoviews
17-
from adaptive.types import Real
17+
from adaptive.types import Int, Real
1818
from adaptive.utils import assign_defaults, partial_function_from_dataframe
1919

2020
try:
@@ -183,7 +183,10 @@ def load_dataframe(
183183
x_name: str = "x",
184184
y_name: str = "y",
185185
):
186-
self.tell_many(df[[seed_name, x_name]].values, df[y_name].values)
186+
# Were using zip instead of df[[seed_name, x_name]].values because that will
187+
# make the seeds into floats
188+
seed_x = list(zip(df[seed_name].values.tolist(), df[x_name].values.tolist()))
189+
self.tell_many(seed_x, df[y_name].values)
187190
if with_default_function_args:
188191
self.function = partial_function_from_dataframe(
189192
self.function, df, function_prefix
@@ -424,7 +427,9 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
424427
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
425428
return t_student * (variance_in_mean / n) ** 0.5
426429

427-
def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
430+
def tell_many(
431+
self, xs: Points | np.ndarray, ys: Sequence[Real] | np.ndarray
432+
) -> None:
428433
# Check that all x are within the bounds
429434
# TODO: remove this requirement, all other learners add the data
430435
# but ignore it going forward.
@@ -435,7 +440,7 @@ def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
435440
)
436441

437442
# Create a mapping of points to a list of samples
438-
mapping: DefaultDict[Real, DefaultDict[int, Real]] = defaultdict(
443+
mapping: DefaultDict[Real, DefaultDict[Int, Real]] = defaultdict(
439444
lambda: defaultdict(dict)
440445
)
441446
for (seed, x), y in zip(xs, ys):

adaptive/learner/learner1D.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,11 @@ def tell_pending(self, x: float) -> None:
563563

564564
def tell_many(
565565
self,
566-
xs: Sequence[Float],
567-
ys: (Sequence[Float] | Sequence[Sequence[Float]] | Sequence[np.ndarray]),
566+
xs: Sequence[Float] | np.ndarray,
567+
ys: Sequence[Float]
568+
| Sequence[Sequence[Float]]
569+
| Sequence[np.ndarray]
570+
| np.ndarray,
568571
*,
569572
force: bool = False,
570573
) -> None:

0 commit comments

Comments
 (0)