Skip to content

Commit 9137aca

Browse files
committed
remove double Unions of float
1 parent bdb85bf commit 9137aca

File tree

6 files changed

+29
-61
lines changed

6 files changed

+29
-61
lines changed

adaptive/learner/learner1D.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919

2020

2121
@uses_nth_neighbors(0)
22-
def uniform_loss(
23-
xs: Union[Tuple[float, float], Tuple[float, float]],
24-
ys: Union[Tuple[float, float], Tuple[float, float]],
25-
) -> Union[float, float]:
22+
def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float],) -> float:
2623
"""Loss function that samples the domain uniformly.
2724
2825
Works with `~adaptive.Learner1D` only.
@@ -62,7 +59,7 @@ def default_loss(
6259

6360

6461
@uses_nth_neighbors(1)
65-
def triangle_loss(xs: Any, ys: Any) -> Union[float, float]:
62+
def triangle_loss(xs: Any, ys: Any) -> float:
6663
xs = [x for x in xs if x is not None]
6764
ys = [y for y in ys if y is not None]
6865

@@ -101,7 +98,7 @@ def curvature_loss(xs, ys):
10198

10299

103100
def linspace(
104-
x_left: Union[int, float, float], x_right: Union[int, float, float], n: int,
101+
x_left: Union[int, float], x_right: Union[int, float], n: int,
105102
) -> Union[List[float], List[float]]:
106103
"""This is equivalent to
107104
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -125,7 +122,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
125122

126123

127124
def _get_intervals(
128-
x: Union[int, float, float], neighbors: SortedDict, nth_neighbors: int
125+
x: Union[int, float], neighbors: SortedDict, nth_neighbors: int
129126
) -> Any:
130127
nn = nth_neighbors
131128
i = neighbors.index(x)
@@ -251,23 +248,21 @@ def npoints(self) -> int:
251248
return len(self.data)
252249

253250
@cache_latest
254-
def loss(self, real: bool = True) -> Union[int, float, float]:
251+
def loss(self, real: bool = True) -> Union[int, float]:
255252
losses = self.losses if real else self.losses_combined
256253
if not losses:
257254
return np.inf
258255
max_interval, max_loss = losses.peekitem(0)
259256
return max_loss
260257

261-
def _scale_x(
262-
self, x: Optional[Union[float, int, float]]
263-
) -> Optional[Union[float, float]]:
258+
def _scale_x(self, x: Optional[Union[float, int]]) -> Optional[float]:
264259
if x is None:
265260
return None
266261
return x / self._scale[0]
267262

268263
def _scale_y(
269264
self, y: Optional[Union[int, np.ndarray, float, float]]
270-
) -> Optional[Union[float, float, np.ndarray]]:
265+
) -> Optional[Union[float, np.ndarray]]:
271266
if y is None:
272267
return None
273268
y_scale = self._scale[1] or 1
@@ -279,8 +274,8 @@ def _get_point_by_index(self, ind: int) -> Optional[Union[int, float, float]]:
279274
return self.neighbors.keys()[ind]
280275

281276
def _get_loss_in_interval(
282-
self, x_left: Union[int, float, float], x_right: Union[int, float, float],
283-
) -> Union[int, float, float]:
277+
self, x_left: Union[int, float], x_right: Union[int, float],
278+
) -> Union[int, float]:
284279
assert x_left is not None and x_right is not None
285280

286281
if x_right - x_left < self._dx_eps:
@@ -301,7 +296,7 @@ def _get_loss_in_interval(
301296
return self.loss_per_interval(xs_scaled, ys_scaled)
302297

303298
def _update_interpolated_loss_in_interval(
304-
self, x_left: Union[int, float, float], x_right: Union[int, float, float],
299+
self, x_left: Union[int, float], x_right: Union[int, float],
305300
) -> None:
306301
if x_left is None or x_right is None:
307302
return
@@ -318,7 +313,7 @@ def _update_interpolated_loss_in_interval(
318313
self.losses_combined[a, b] = (b - a) * loss / dx
319314
a = b
320315

321-
def _update_losses(self, x: Union[int, float, float], real: bool = True) -> None:
316+
def _update_losses(self, x: Union[int, float], real: bool = True) -> None:
322317
"""Update all losses that depend on x"""
323318
# When we add a new point x, we should update the losses
324319
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -361,7 +356,7 @@ def _update_losses(self, x: Union[int, float, float], real: bool = True) -> None
361356
self.losses_combined[x, b] = float("inf")
362357

363358
@staticmethod
364-
def _find_neighbors(x: Union[int, float, float], neighbors: SortedDict) -> Any:
359+
def _find_neighbors(x: Union[int, float], neighbors: SortedDict) -> Any:
365360
if x in neighbors:
366361
return neighbors[x]
367362
pos = neighbors.bisect_left(x)
@@ -370,17 +365,15 @@ def _find_neighbors(x: Union[int, float, float], neighbors: SortedDict) -> Any:
370365
x_right = keys[pos] if pos != len(neighbors) else None
371366
return x_left, x_right
372367

373-
def _update_neighbors(
374-
self, x: Union[int, float, float], neighbors: SortedDict
375-
) -> None:
368+
def _update_neighbors(self, x: Union[int, float], neighbors: SortedDict) -> None:
376369
if x not in neighbors: # The point is new
377370
x_left, x_right = self._find_neighbors(x, neighbors)
378371
neighbors[x] = [x_left, x_right]
379372
neighbors.get(x_left, [None, None])[1] = x
380373
neighbors.get(x_right, [None, None])[0] = x
381374

382375
def _update_scale(
383-
self, x: Union[int, float, float], y: Union[float, int, float, np.ndarray],
376+
self, x: Union[int, float], y: Union[float, int, float, np.ndarray],
384377
) -> None:
385378
"""Update the scale with which the x and y-values are scaled.
386379
@@ -408,7 +401,7 @@ def _update_scale(
408401
self._bbox[1][1] = max(self._bbox[1][1], y)
409402
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
410403

411-
def tell(self, x: Union[int, float, float], y: Any) -> None:
404+
def tell(self, x: Union[int, float], y: Any) -> None:
412405
if x in self.data:
413406
# The point is already evaluated before
414407
return
@@ -443,7 +436,7 @@ def tell(self, x: Union[int, float, float], y: Any) -> None:
443436

444437
self._oldscale = deepcopy(self._scale)
445438

446-
def tell_pending(self, x: Union[int, float, float]) -> None:
439+
def tell_pending(self, x: Union[int, float]) -> None:
447440
if x in self.data:
448441
# The point is already evaluated before
449442
return
@@ -659,7 +652,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
659652
self.tell_many(*zip(*data.items()))
660653

661654

662-
def loss_manager(x_scale: Union[int, float, float]) -> ItemSortedDict:
655+
def loss_manager(x_scale: Union[int, float]) -> ItemSortedDict:
663656
def sort_key(ival, loss):
664657
loss, ival = finite_loss(ival, loss, x_scale)
665658
return -loss, ival
@@ -668,9 +661,7 @@ def sort_key(ival, loss):
668661
return sorted_dict
669662

670663

671-
def finite_loss(
672-
ival: Any, loss: Union[int, float, float], x_scale: Union[int, float, float],
673-
) -> Any:
664+
def finite_loss(ival: Any, loss: Union[int, float], x_scale: Union[int, float],) -> Any:
674665
"""Get the socalled finite_loss of an interval in order to be able to
675666
sort intervals that have infinite loss."""
676667
# If the loss is infinite we return the

adaptive/learner/learner2D.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -583,15 +583,7 @@ def tell_pending(
583583

584584
def _fill_stack(
585585
self, stack_till: int = 1
586-
) -> Union[
587-
Tuple[List[Tuple[float, float]], List[float]],
588-
Tuple[
589-
List[Union[Tuple[float, float], Tuple[float, float], Tuple[float, float]]],
590-
List[float],
591-
],
592-
Tuple[List[Union[Tuple[float, float], Tuple[float, float]]], List[float]],
593-
Tuple[List[Union[Tuple[float, float], Tuple[float, float]]], List[float]],
594-
]:
586+
) -> Tuple[List[Tuple[float, float]], List[float]]:
595587
if len(self.data) + len(self.pending_points) < self.ndim + 1:
596588
raise ValueError("too few points...")
597589

adaptive/learner/learnerND.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import random
44
from collections import OrderedDict
55
from collections.abc import Iterable
6-
from functools import partial
76
from typing import Any, Callable, List, Optional, Tuple, Union
87

98
import numpy as np
@@ -133,7 +132,7 @@ def triangle_loss(
133132
value_scale: float,
134133
neighbors: Union[List[Union[None, np.ndarray]], List[None], List[np.ndarray]],
135134
neighbor_values: Union[List[Union[None, float]], List[None], List[float]],
136-
) -> Union[int, float, float]:
135+
) -> Union[int, float]:
137136
"""
138137
Computes the average of the volumes of the simplex combined with each
139138
neighbouring point.
@@ -318,7 +317,7 @@ class LearnerND(BaseLearner):
318317

319318
def __init__(
320319
self,
321-
func: Union[Callable, partial],
320+
func: Callable,
322321
bounds: Union[
323322
Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
324323
np.ndarray,

adaptive/learner/sequence_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __call__(
2626
index_point: Union[Tuple[int, int], Tuple[int, float], Tuple[int, np.ndarray]],
2727
*args,
2828
**kwargs
29-
) -> Union[float, float]:
29+
) -> float:
3030
index, point = index_point
3131
return self.function(point, *args, **kwargs)
3232

@@ -127,7 +127,7 @@ def tell(
127127
Tuple[int, np.ndarray],
128128
Tuple[int, None],
129129
],
130-
value: Union[float, float],
130+
value: float,
131131
) -> None:
132132
index, point = point
133133
self.data[index] = value

adaptive/learner/skopt_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def remove_unfinished(self):
4949
pass
5050

5151
@cache_latest
52-
def loss(self, real: bool = True) -> Union[float, float]:
52+
def loss(self, real: bool = True) -> float:
5353
if not self.models:
5454
return np.inf
5555
else:

adaptive/tests/test_learners.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,13 @@ def maybe_skip(learner: Type[SKOptLearner]) -> Type[SKOptLearner]:
126126

127127

128128
@learn_with(Learner1D, bounds=(-1, 1))
129-
def quadratic(
130-
x: Union[int, float, float], m: uniform(0, 10), b: uniform(0, 1)
131-
) -> Union[float, float]:
129+
def quadratic(x: Union[int, float], m: uniform(0, 10), b: uniform(0, 1)) -> float:
132130
return m * x ** 2 + b
133131

134132

135133
@learn_with(Learner1D, bounds=(-1, 1))
136134
@learn_with(SequenceLearner, sequence=np.linspace(-1, 1, 201))
137-
def linear_with_peak(
138-
x: Union[int, float, float], d: uniform(-1, 1)
139-
) -> Union[float, float]:
135+
def linear_with_peak(x: Union[int, float], d: uniform(-1, 1)) -> float:
140136
a = 0.01
141137
return x + a ** 2 / (a ** 2 + (x - d) ** 2)
142138

@@ -153,7 +149,7 @@ def ring_of_fire(
153149
Tuple[float, float],
154150
],
155151
d: uniform(0.2, 1),
156-
) -> Union[float, float]:
152+
) -> float:
157153
a = 0.2
158154
x, y = xy
159155
return x + math.exp(-((x ** 2 + y ** 2 - d ** 2) ** 2) / a ** 4)
@@ -164,7 +160,7 @@ def ring_of_fire(
164160
def sphere_of_fire(
165161
xyz: Union[Tuple[float, float, float], Tuple[int, int, int], np.ndarray],
166162
d: uniform(0.2, 1),
167-
) -> Union[float, float]:
163+
) -> float:
168164
a = 0.2
169165
x, y, z = xyz
170166
return x + math.exp(-((x ** 2 + y ** 2 + z ** 2 - d ** 2) ** 2) / a ** 4) + z ** 2
@@ -224,17 +220,7 @@ def ask_randomly(
224220
Tuple[List[Union[Tuple[float, float, float], Tuple[int, int, int]]], List[float]],
225221
Tuple[List[Union[Tuple[float, float], Tuple[int, int]]], List[float]],
226222
Tuple[List[float], List[float]],
227-
Tuple[
228-
List[
229-
Union[
230-
Tuple[int, int],
231-
Tuple[float, float],
232-
Tuple[float, float],
233-
Tuple[float, float],
234-
]
235-
],
236-
List[Union[float, float]],
237-
],
223+
Tuple[List[Union[Tuple[int, int], Tuple[float, float]]], List[float]],
238224
]:
239225
n_rounds = random.randrange(*rounds)
240226
n_points = [random.randrange(*points) for _ in range(n_rounds)]

0 commit comments

Comments
 (0)