Skip to content

Commit 98979e4

Browse files
committed
WIP: allow the loss to be a tuple in the Learner1D
1 parent 0e1aa7c commit 98979e4

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

adaptive/learner/learner1D.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@ def _wrapped(loss_per_interval):
6969
return _wrapped
7070

7171

72+
def loss_returns(return_type, return_length):
73+
def _wrapped(loss_per_interval):
74+
loss_per_interval.return_type = return_type
75+
loss_per_interval.return_length = return_length
76+
return loss_per_interval
77+
return _wrapped
78+
79+
80+
def inf_format(return_type, return_len=None):
81+
is_iterable = hasattr(return_type, '__iter__')
82+
if is_iterable:
83+
return return_type(return_len * [np.inf])
84+
else:
85+
return return_type(np.inf)
86+
87+
88+
def ensure_tuple(x):
89+
if not isinstance(x, Iterable):
90+
x = (x,)
91+
return x
92+
93+
7294
@uses_nth_neighbors(0)
7395
def uniform_loss(xs, ys):
7496
"""Loss function that samples the domain uniformly.
@@ -287,7 +309,8 @@ def npoints(self):
287309
def loss(self, real=True):
288310
losses = self.losses if real else self.losses_combined
289311
if not losses:
290-
return np.inf
312+
return inf_format(self.loss_per_interval.return_type,
313+
self.loss_per_interval.return_length)
291314
max_interval, max_loss = losses.peekitem(0)
292315
return max_loss
293316

@@ -325,7 +348,7 @@ def _get_loss_in_interval(self, x_left, x_right):
325348
ys_scaled = tuple(self._scale_y(y) for y in ys)
326349

327350
# we need to compute the loss for this interval
328-
return self.loss_per_interval(xs_scaled, ys_scaled)
351+
return ensure_tuple(self.loss_per_interval(xs_scaled, ys_scaled))
329352

330353
def _update_interpolated_loss_in_interval(self, x_left, x_right):
331354
if x_left is None or x_right is None:
@@ -379,13 +402,17 @@ def _update_losses(self, x, real=True):
379402
left_loss_is_unknown = ((x_left is None) or
380403
(not real and x_right is None))
381404
if (a is not None) and left_loss_is_unknown:
382-
self.losses_combined[a, x] = float('inf')
405+
self.losses_combined[a, x] = inf_format(
406+
self.loss_per_interval.return_type,
407+
self.loss_per_interval.return_length)
383408

384409
# (no real point right of x) or (no real point left of b)
385410
right_loss_is_unknown = ((x_right is None) or
386411
(not real and x_left is None))
387412
if (b is not None) and right_loss_is_unknown:
388-
self.losses_combined[x, b] = float('inf')
413+
self.losses_combined[x, b] = inf_format(
414+
self.loss_per_interval.return_type,
415+
self.loss_per_interval.return_length)
389416

390417
@staticmethod
391418
def _find_neighbors(x, neighbors):
@@ -660,8 +687,8 @@ def _set_data(self, data):
660687

661688
def loss_manager(x_scale):
662689
def sort_key(ival, loss):
663-
loss, ival = finite_loss(ival, loss, x_scale)
664-
return -loss, ival
690+
loss = [-finite_loss(ival, l, x_scale)[0] for l in loss]
691+
return loss, ival
665692
sorted_dict = sortedcollections.ItemSortedDict(sort_key)
666693
return sorted_dict
667694

0 commit comments

Comments
 (0)