|
2 | 2 | import itertools
|
3 | 3 | import math
|
4 | 4 | from copy import deepcopy
|
5 |
| -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union |
| 5 | +from typing import ( |
| 6 | + Any, |
| 7 | + Callable, |
| 8 | + Dict, |
| 9 | + Iterable, |
| 10 | + List, |
| 11 | + Literal, |
| 12 | + Optional, |
| 13 | + Sequence, |
| 14 | + Tuple, |
| 15 | + Union, |
| 16 | +) |
6 | 17 |
|
7 | 18 | import numpy as np
|
8 | 19 | from sortedcollections.recipes import ItemSortedDict
|
@@ -64,7 +75,11 @@ def abs_min_log_loss(xs, ys):
|
64 | 75 |
|
65 | 76 | @uses_nth_neighbors(1)
|
66 | 77 | def triangle_loss(
|
67 |
| - xs: Sequence[float], ys: Union[Iterable[float], Iterable[Iterable[float]]] |
| 78 | + xs: Sequence[Union[float, None]], |
| 79 | + ys: Union[ |
| 80 | + Iterable[Union[float, None]], |
| 81 | + Iterable[Union[Iterable[float], None]], |
| 82 | + ], |
68 | 83 | ) -> float:
|
69 | 84 | xs = [x for x in xs if x is not None]
|
70 | 85 | ys = [y for y in ys if y is not None]
|
@@ -399,7 +414,7 @@ def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None:
|
399 | 414 | self._bbox[1][1] = max(self._bbox[1][1], y)
|
400 | 415 | self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
|
401 | 416 |
|
402 |
| - def tell(self, x: float, y: Union[float, np.ndarray]) -> None: |
| 417 | + def tell(self, x: float, y: Union[float, Sequence[float], np.ndarray]) -> None: |
403 | 418 | if x in self.data:
|
404 | 419 | # The point is already evaluated before
|
405 | 420 | return
|
@@ -442,7 +457,7 @@ def tell_pending(self, x: float) -> None:
|
442 | 457 | self._update_neighbors(x, self.neighbors_combined)
|
443 | 458 | self._update_losses(x, real=False)
|
444 | 459 |
|
445 |
| - def tell_many(self, xs: List[float], ys: List[Any], *, force=False) -> None: |
| 460 | + def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None: |
446 | 461 | if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
|
447 | 462 | # Only run this more efficient method if there are
|
448 | 463 | # at least 2 points and the amount of points added are
|
@@ -602,7 +617,7 @@ def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
|
602 | 617 | loss = mapping[ival]
|
603 | 618 | return finite_loss(ival, loss, self._scale[0])
|
604 | 619 |
|
605 |
| - def plot(self, *, scatter_or_line="scatter"): |
| 620 | + def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"): |
606 | 621 | """Returns a plot of the evaluated data.
|
607 | 622 |
|
608 | 623 | Parameters
|
|
0 commit comments