@@ -69,6 +69,28 @@ def _wrapped(loss_per_interval):
69
69
return _wrapped
70
70
71
71
72
+ def loss_returns (return_type , return_length ):
73
+ def _wrapped (loss_per_interval ):
74
+ loss_per_interval .return_type = return_type
75
+ loss_per_interval .return_length = return_length
76
+ return loss_per_interval
77
+ return _wrapped
78
+
79
+
80
+ def inf_format (return_type , return_len = None ):
81
+ is_iterable = hasattr (return_type , '__iter__' )
82
+ if is_iterable :
83
+ return return_type (return_len * [np .inf ])
84
+ else :
85
+ return return_type (np .inf )
86
+
87
+
88
+ def ensure_tuple (x ):
89
+ if not isinstance (x , Iterable ):
90
+ x = (x ,)
91
+ return x
92
+
93
+
72
94
@uses_nth_neighbors (0 )
73
95
def uniform_loss (xs , ys ):
74
96
"""Loss function that samples the domain uniformly.
@@ -287,7 +309,8 @@ def npoints(self):
287
309
def loss (self , real = True ):
288
310
losses = self .losses if real else self .losses_combined
289
311
if not losses :
290
- return np .inf
312
+ return inf_format (self .loss_per_interval .return_type ,
313
+ self .loss_per_interval .return_length )
291
314
max_interval , max_loss = losses .peekitem (0 )
292
315
return max_loss
293
316
@@ -325,7 +348,7 @@ def _get_loss_in_interval(self, x_left, x_right):
325
348
ys_scaled = tuple (self ._scale_y (y ) for y in ys )
326
349
327
350
# we need to compute the loss for this interval
328
- return self .loss_per_interval (xs_scaled , ys_scaled )
351
+ return ensure_tuple ( self .loss_per_interval (xs_scaled , ys_scaled ) )
329
352
330
353
def _update_interpolated_loss_in_interval (self , x_left , x_right ):
331
354
if x_left is None or x_right is None :
@@ -379,13 +402,17 @@ def _update_losses(self, x, real=True):
379
402
left_loss_is_unknown = ((x_left is None ) or
380
403
(not real and x_right is None ))
381
404
if (a is not None ) and left_loss_is_unknown :
382
- self .losses_combined [a , x ] = float ('inf' )
405
+ self .losses_combined [a , x ] = inf_format (
406
+ self .loss_per_interval .return_type ,
407
+ self .loss_per_interval .return_length )
383
408
384
409
# (no real point right of x) or (no real point left of b)
385
410
right_loss_is_unknown = ((x_right is None ) or
386
411
(not real and x_left is None ))
387
412
if (b is not None ) and right_loss_is_unknown :
388
- self .losses_combined [x , b ] = float ('inf' )
413
+ self .losses_combined [x , b ] = inf_format (
414
+ self .loss_per_interval .return_type ,
415
+ self .loss_per_interval .return_length )
389
416
390
417
@staticmethod
391
418
def _find_neighbors (x , neighbors ):
@@ -660,8 +687,8 @@ def _set_data(self, data):
660
687
661
688
def loss_manager (x_scale ):
662
689
def sort_key (ival , loss ):
663
- loss , ival = finite_loss (ival , loss , x_scale )
664
- return - loss , ival
690
+ loss = [ - finite_loss (ival , l , x_scale )[ 0 ] for l in loss ]
691
+ return loss , ival
665
692
sorted_dict = sortedcollections .ItemSortedDict (sort_key )
666
693
return sorted_dict
667
694
0 commit comments