Skip to content

Commit a5b50bc

Browse files
committed
pass pytest --typeguard-packages=adaptive adaptive/tests/test_learner1d.py
1 parent 4d54b6b commit a5b50bc

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

adaptive/learner/learner1D.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22
import itertools
33
import math
44
from copy import deepcopy
5-
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
5+
from typing import (
6+
Any,
7+
Callable,
8+
Dict,
9+
Iterable,
10+
List,
11+
Literal,
12+
Optional,
13+
Sequence,
14+
Tuple,
15+
Union,
16+
)
617

718
import numpy as np
819
from sortedcollections.recipes import ItemSortedDict
@@ -64,7 +75,11 @@ def abs_min_log_loss(xs, ys):
6475

6576
@uses_nth_neighbors(1)
6677
def triangle_loss(
67-
xs: Sequence[float], ys: Union[Iterable[float], Iterable[Iterable[float]]]
78+
xs: Sequence[Union[float, None]],
79+
ys: Union[
80+
Iterable[Union[float, None]],
81+
Iterable[Union[Iterable[float], None]],
82+
],
6883
) -> float:
6984
xs = [x for x in xs if x is not None]
7085
ys = [y for y in ys if y is not None]
@@ -399,7 +414,7 @@ def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None:
399414
self._bbox[1][1] = max(self._bbox[1][1], y)
400415
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
401416

402-
def tell(self, x: float, y: Union[float, np.ndarray]) -> None:
417+
def tell(self, x: float, y: Union[float, Sequence[float], np.ndarray]) -> None:
403418
if x in self.data:
404419
# The point is already evaluated before
405420
return
@@ -442,7 +457,7 @@ def tell_pending(self, x: float) -> None:
442457
self._update_neighbors(x, self.neighbors_combined)
443458
self._update_losses(x, real=False)
444459

445-
def tell_many(self, xs: List[float], ys: List[Any], *, force=False) -> None:
460+
def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None:
446461
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
447462
# Only run this more efficient method if there are
448463
# at least 2 points and the amount of points added are
@@ -602,7 +617,7 @@ def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
602617
loss = mapping[ival]
603618
return finite_loss(ival, loss, self._scale[0])
604619

605-
def plot(self, *, scatter_or_line="scatter"):
620+
def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"):
606621
"""Returns a plot of the evaluated data.
607622
608623
Parameters

0 commit comments

Comments
 (0)