19
19
20
20
21
21
@uses_nth_neighbors (0 )
22
- def uniform_loss (
23
- xs : Union [Tuple [float , float ], Tuple [float , float ]],
24
- ys : Union [Tuple [float , float ], Tuple [float , float ]],
25
- ) -> Union [float , float ]:
22
+ def uniform_loss (xs : Tuple [float , float ], ys : Tuple [float , float ],) -> float :
26
23
"""Loss function that samples the domain uniformly.
27
24
28
25
Works with `~adaptive.Learner1D` only.
@@ -62,7 +59,7 @@ def default_loss(
62
59
63
60
64
61
@uses_nth_neighbors (1 )
65
- def triangle_loss (xs : Any , ys : Any ) -> Union [ float , float ] :
62
+ def triangle_loss (xs : Any , ys : Any ) -> float :
66
63
xs = [x for x in xs if x is not None ]
67
64
ys = [y for y in ys if y is not None ]
68
65
@@ -101,7 +98,7 @@ def curvature_loss(xs, ys):
101
98
102
99
103
100
def linspace (
104
- x_left : Union [int , float , float ], x_right : Union [int , float , float ], n : int ,
101
+ x_left : Union [int , float ], x_right : Union [int , float ], n : int ,
105
102
) -> Union [List [float ], List [float ]]:
106
103
"""This is equivalent to
107
104
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -125,7 +122,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
125
122
126
123
127
124
def _get_intervals (
128
- x : Union [int , float , float ], neighbors : SortedDict , nth_neighbors : int
125
+ x : Union [int , float ], neighbors : SortedDict , nth_neighbors : int
129
126
) -> Any :
130
127
nn = nth_neighbors
131
128
i = neighbors .index (x )
@@ -251,23 +248,21 @@ def npoints(self) -> int:
251
248
return len (self .data )
252
249
253
250
@cache_latest
254
- def loss (self , real : bool = True ) -> Union [int , float , float ]:
251
+ def loss (self , real : bool = True ) -> Union [int , float ]:
255
252
losses = self .losses if real else self .losses_combined
256
253
if not losses :
257
254
return np .inf
258
255
max_interval , max_loss = losses .peekitem (0 )
259
256
return max_loss
260
257
261
- def _scale_x (
262
- self , x : Optional [Union [float , int , float ]]
263
- ) -> Optional [Union [float , float ]]:
258
+ def _scale_x (self , x : Optional [Union [float , int ]]) -> Optional [float ]:
264
259
if x is None :
265
260
return None
266
261
return x / self ._scale [0 ]
267
262
268
263
def _scale_y (
269
264
self , y : Optional [Union [int , np .ndarray , float , float ]]
270
- ) -> Optional [Union [float , float , np .ndarray ]]:
265
+ ) -> Optional [Union [float , np .ndarray ]]:
271
266
if y is None :
272
267
return None
273
268
y_scale = self ._scale [1 ] or 1
@@ -279,8 +274,8 @@ def _get_point_by_index(self, ind: int) -> Optional[Union[int, float, float]]:
279
274
return self .neighbors .keys ()[ind ]
280
275
281
276
def _get_loss_in_interval (
282
- self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
283
- ) -> Union [int , float , float ]:
277
+ self , x_left : Union [int , float ], x_right : Union [int , float ],
278
+ ) -> Union [int , float ]:
284
279
assert x_left is not None and x_right is not None
285
280
286
281
if x_right - x_left < self ._dx_eps :
@@ -301,7 +296,7 @@ def _get_loss_in_interval(
301
296
return self .loss_per_interval (xs_scaled , ys_scaled )
302
297
303
298
def _update_interpolated_loss_in_interval (
304
- self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
299
+ self , x_left : Union [int , float ], x_right : Union [int , float ],
305
300
) -> None :
306
301
if x_left is None or x_right is None :
307
302
return
@@ -318,7 +313,7 @@ def _update_interpolated_loss_in_interval(
318
313
self .losses_combined [a , b ] = (b - a ) * loss / dx
319
314
a = b
320
315
321
- def _update_losses (self , x : Union [int , float , float ], real : bool = True ) -> None :
316
+ def _update_losses (self , x : Union [int , float ], real : bool = True ) -> None :
322
317
"""Update all losses that depend on x"""
323
318
# When we add a new point x, we should update the losses
324
319
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -361,7 +356,7 @@ def _update_losses(self, x: Union[int, float, float], real: bool = True) -> None
361
356
self .losses_combined [x , b ] = float ("inf" )
362
357
363
358
@staticmethod
364
- def _find_neighbors (x : Union [int , float , float ], neighbors : SortedDict ) -> Any :
359
+ def _find_neighbors (x : Union [int , float ], neighbors : SortedDict ) -> Any :
365
360
if x in neighbors :
366
361
return neighbors [x ]
367
362
pos = neighbors .bisect_left (x )
@@ -370,17 +365,15 @@ def _find_neighbors(x: Union[int, float, float], neighbors: SortedDict) -> Any:
370
365
x_right = keys [pos ] if pos != len (neighbors ) else None
371
366
return x_left , x_right
372
367
373
- def _update_neighbors (
374
- self , x : Union [int , float , float ], neighbors : SortedDict
375
- ) -> None :
368
+ def _update_neighbors (self , x : Union [int , float ], neighbors : SortedDict ) -> None :
376
369
if x not in neighbors : # The point is new
377
370
x_left , x_right = self ._find_neighbors (x , neighbors )
378
371
neighbors [x ] = [x_left , x_right ]
379
372
neighbors .get (x_left , [None , None ])[1 ] = x
380
373
neighbors .get (x_right , [None , None ])[0 ] = x
381
374
382
375
def _update_scale (
383
- self , x : Union [int , float , float ], y : Union [float , int , float , np .ndarray ],
376
+ self , x : Union [int , float ], y : Union [float , int , float , np .ndarray ],
384
377
) -> None :
385
378
"""Update the scale with which the x and y-values are scaled.
386
379
@@ -408,7 +401,7 @@ def _update_scale(
408
401
self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
409
402
self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
410
403
411
- def tell (self , x : Union [int , float , float ], y : Any ) -> None :
404
+ def tell (self , x : Union [int , float ], y : Any ) -> None :
412
405
if x in self .data :
413
406
# The point is already evaluated before
414
407
return
@@ -443,7 +436,7 @@ def tell(self, x: Union[int, float, float], y: Any) -> None:
443
436
444
437
self ._oldscale = deepcopy (self ._scale )
445
438
446
- def tell_pending (self , x : Union [int , float , float ]) -> None :
439
+ def tell_pending (self , x : Union [int , float ]) -> None :
447
440
if x in self .data :
448
441
# The point is already evaluated before
449
442
return
@@ -659,7 +652,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
659
652
self .tell_many (* zip (* data .items ()))
660
653
661
654
662
- def loss_manager (x_scale : Union [int , float , float ]) -> ItemSortedDict :
655
+ def loss_manager (x_scale : Union [int , float ]) -> ItemSortedDict :
663
656
def sort_key (ival , loss ):
664
657
loss , ival = finite_loss (ival , loss , x_scale )
665
658
return - loss , ival
@@ -668,9 +661,7 @@ def sort_key(ival, loss):
668
661
return sorted_dict
669
662
670
663
671
- def finite_loss (
672
- ival : Any , loss : Union [int , float , float ], x_scale : Union [int , float , float ],
673
- ) -> Any :
664
+ def finite_loss (ival : Any , loss : Union [int , float ], x_scale : Union [int , float ],) -> Any :
674
665
"""Get the socalled finite_loss of an interval in order to be able to
675
666
sort intervals that have infinite loss."""
676
667
# If the loss is infinite we return the
0 commit comments