20
20
21
21
@uses_nth_neighbors (0 )
22
22
def uniform_loss (
23
- xs : Union [Tuple [float , float ], Tuple [np . float64 , np . float64 ]],
24
- ys : Union [Tuple [float , float ], Tuple [np . float64 , np . float64 ]],
25
- ) -> Union [np . float64 , float ]:
23
+ xs : Union [Tuple [float , float ], Tuple [float , float ]],
24
+ ys : Union [Tuple [float , float ], Tuple [float , float ]],
25
+ ) -> Union [float , float ]:
26
26
"""Loss function that samples the domain uniformly.
27
27
28
28
Works with `~adaptive.Learner1D` only.
@@ -43,18 +43,9 @@ def uniform_loss(
43
43
44
44
@uses_nth_neighbors (0 )
45
45
def default_loss (
46
- xs : Union [
47
- Tuple [float , float ],
48
- Tuple [np .float64 , float ],
49
- Tuple [np .float64 , np .float64 ],
50
- Tuple [float , np .float64 ],
51
- ],
52
- ys : Union [
53
- Tuple [float , float ],
54
- Tuple [np .ndarray , np .ndarray ],
55
- Tuple [np .float64 , np .float64 ],
56
- ],
57
- ) -> np .float64 :
46
+ xs : Tuple [float , float ],
47
+ ys : Union [Tuple [np .ndarray , np .ndarray ], Tuple [float , float ]],
48
+ ) -> float :
58
49
"""Calculate loss on a single interval.
59
50
60
51
Currently returns the rescaled length of the interval. If one of the
@@ -71,7 +62,7 @@ def default_loss(
71
62
72
63
73
64
@uses_nth_neighbors (1 )
74
- def triangle_loss (xs : Any , ys : Any ) -> Union [np . float64 , float ]:
65
+ def triangle_loss (xs : Any , ys : Any ) -> Union [float , float ]:
75
66
xs = [x for x in xs if x is not None ]
76
67
ys = [y for y in ys if y is not None ]
77
68
@@ -110,10 +101,8 @@ def curvature_loss(xs, ys):
110
101
111
102
112
103
def linspace (
113
- x_left : Union [int , np .float64 , float ],
114
- x_right : Union [int , np .float64 , float ],
115
- n : int ,
116
- ) -> Union [List [float ], List [np .float64 ]]:
104
+ x_left : Union [int , float , float ], x_right : Union [int , float , float ], n : int ,
105
+ ) -> Union [List [float ], List [float ]]:
117
106
"""This is equivalent to
118
107
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
119
108
but it is 15-30 times faster for small 'n'."""
@@ -136,7 +125,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
136
125
137
126
138
127
def _get_intervals (
139
- x : Union [int , np . float64 , float ], neighbors : SortedDict , nth_neighbors : int
128
+ x : Union [int , float , float ], neighbors : SortedDict , nth_neighbors : int
140
129
) -> Any :
141
130
nn = nth_neighbors
142
131
i = neighbors .index (x )
@@ -262,38 +251,36 @@ def npoints(self) -> int:
262
251
return len (self .data )
263
252
264
253
@cache_latest
265
- def loss (self , real : bool = True ) -> Union [int , np . float64 , float ]:
254
+ def loss (self , real : bool = True ) -> Union [int , float , float ]:
266
255
losses = self .losses if real else self .losses_combined
267
256
if not losses :
268
257
return np .inf
269
258
max_interval , max_loss = losses .peekitem (0 )
270
259
return max_loss
271
260
272
261
def _scale_x (
273
- self , x : Optional [Union [float , int , np . float64 ]]
274
- ) -> Optional [Union [float , np . float64 ]]:
262
+ self , x : Optional [Union [float , int , float ]]
263
+ ) -> Optional [Union [float , float ]]:
275
264
if x is None :
276
265
return None
277
266
return x / self ._scale [0 ]
278
267
279
268
def _scale_y (
280
- self , y : Optional [Union [int , np .ndarray , np . float64 , float ]]
281
- ) -> Optional [Union [float , np . float64 , np .ndarray ]]:
269
+ self , y : Optional [Union [int , np .ndarray , float , float ]]
270
+ ) -> Optional [Union [float , float , np .ndarray ]]:
282
271
if y is None :
283
272
return None
284
273
y_scale = self ._scale [1 ] or 1
285
274
return y / y_scale
286
275
287
- def _get_point_by_index (self , ind : int ) -> Optional [Union [int , np . float64 , float ]]:
276
+ def _get_point_by_index (self , ind : int ) -> Optional [Union [int , float , float ]]:
288
277
if ind < 0 or ind >= len (self .neighbors ):
289
278
return None
290
279
return self .neighbors .keys ()[ind ]
291
280
292
281
def _get_loss_in_interval (
293
- self ,
294
- x_left : Union [int , np .float64 , float ],
295
- x_right : Union [int , np .float64 , float ],
296
- ) -> Union [int , np .float64 , float ]:
282
+ self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
283
+ ) -> Union [int , float , float ]:
297
284
assert x_left is not None and x_right is not None
298
285
299
286
if x_right - x_left < self ._dx_eps :
@@ -314,9 +301,7 @@ def _get_loss_in_interval(
314
301
return self .loss_per_interval (xs_scaled , ys_scaled )
315
302
316
303
def _update_interpolated_loss_in_interval (
317
- self ,
318
- x_left : Union [int , np .float64 , float ],
319
- x_right : Union [int , np .float64 , float ],
304
+ self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
320
305
) -> None :
321
306
if x_left is None or x_right is None :
322
307
return
@@ -333,9 +318,7 @@ def _update_interpolated_loss_in_interval(
333
318
self .losses_combined [a , b ] = (b - a ) * loss / dx
334
319
a = b
335
320
336
- def _update_losses (
337
- self , x : Union [int , np .float64 , float ], real : bool = True
338
- ) -> None :
321
+ def _update_losses (self , x : Union [int , float , float ], real : bool = True ) -> None :
339
322
"""Update all losses that depend on x"""
340
323
# When we add a new point x, we should update the losses
341
324
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -378,7 +361,7 @@ def _update_losses(
378
361
self .losses_combined [x , b ] = float ("inf" )
379
362
380
363
@staticmethod
381
- def _find_neighbors (x : Union [int , np . float64 , float ], neighbors : SortedDict ) -> Any :
364
+ def _find_neighbors (x : Union [int , float , float ], neighbors : SortedDict ) -> Any :
382
365
if x in neighbors :
383
366
return neighbors [x ]
384
367
pos = neighbors .bisect_left (x )
@@ -388,7 +371,7 @@ def _find_neighbors(x: Union[int, np.float64, float], neighbors: SortedDict) ->
388
371
return x_left , x_right
389
372
390
373
def _update_neighbors (
391
- self , x : Union [int , np . float64 , float ], neighbors : SortedDict
374
+ self , x : Union [int , float , float ], neighbors : SortedDict
392
375
) -> None :
393
376
if x not in neighbors : # The point is new
394
377
x_left , x_right = self ._find_neighbors (x , neighbors )
@@ -397,9 +380,7 @@ def _update_neighbors(
397
380
neighbors .get (x_right , [None , None ])[0 ] = x
398
381
399
382
def _update_scale (
400
- self ,
401
- x : Union [int , np .float64 , float ],
402
- y : Union [float , int , np .float64 , np .ndarray ],
383
+ self , x : Union [int , float , float ], y : Union [float , int , float , np .ndarray ],
403
384
) -> None :
404
385
"""Update the scale with which the x and y-values are scaled.
405
386
@@ -427,7 +408,7 @@ def _update_scale(
427
408
self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
428
409
self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
429
410
430
- def tell (self , x : Union [int , np . float64 , float ], y : Any ) -> None :
411
+ def tell (self , x : Union [int , float , float ], y : Any ) -> None :
431
412
if x in self .data :
432
413
# The point is already evaluated before
433
414
return
@@ -462,7 +443,7 @@ def tell(self, x: Union[int, np.float64, float], y: Any) -> None:
462
443
463
444
self ._oldscale = deepcopy (self ._scale )
464
445
465
- def tell_pending (self , x : Union [int , np . float64 , float ]) -> None :
446
+ def tell_pending (self , x : Union [int , float , float ]) -> None :
466
447
if x in self .data :
467
448
# The point is already evaluated before
468
449
return
@@ -678,7 +659,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
678
659
self .tell_many (* zip (* data .items ()))
679
660
680
661
681
- def loss_manager (x_scale : Union [int , np . float64 , float ]) -> ItemSortedDict :
662
+ def loss_manager (x_scale : Union [int , float , float ]) -> ItemSortedDict :
682
663
def sort_key (ival , loss ):
683
664
loss , ival = finite_loss (ival , loss , x_scale )
684
665
return - loss , ival
@@ -688,9 +669,7 @@ def sort_key(ival, loss):
688
669
689
670
690
671
def finite_loss (
691
- ival : Any ,
692
- loss : Union [int , np .float64 , float ],
693
- x_scale : Union [int , np .float64 , float ],
672
+ ival : Any , loss : Union [int , float , float ], x_scale : Union [int , float , float ],
694
673
) -> Any :
695
674
"""Get the socalled finite_loss of an interval in order to be able to
696
675
sort intervals that have infinite loss."""
0 commit comments