Commit fde1774

jhoofwijk authored and basnijholt committed
change loss function signature
1 parent 2733128 commit fde1774

File tree

4 files changed: +107 -101 lines changed

adaptive/learner/learner1D.py

Lines changed: 101 additions & 95 deletions
@@ -21,44 +21,42 @@ def uses_nth_neighbors(n):
     Wraps loss functions to indicate that they expect intervals together
     with ``n`` nearest neighbors

-    The loss function is then guaranteed to receive the data of at least the
-    N nearest neighbors (``nth_neighbors``) in a dict that tells you what the
-    neighboring points of these are. And the `~adaptive.Learner1D` will
-    then make sure that the loss is updated whenever one of the
-    ``nth_neighbors`` changes.
+    The loss function will then receive the data of the N nearest neighbors
+    (``nth_neighbors``) along with the data of the interval itself.
+    The `~adaptive.Learner1D` will also make sure that the loss is updated
+    whenever one of the ``nth_neighbors`` changes.

     Examples
     --------

-    The next function is a part of the `get_curvature_loss` function.
+    The next function is part of the `curvature_loss_function` function.

     >>> @uses_nth_neighbors(1)
-    ... def triangle_loss(interval, scale, data, neighbors):
-    ...     x_left, x_right = interval
-    ...     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
-    ...     # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
-    ...     xs = [x for x in xs if x is not None]
-    ...     if len(xs) <= 2:
-    ...         return (x_right - x_left) / scale[0]
+    ... def triangle_loss(xs, ys):
+    ...     xs = [x for x in xs if x is not None]
+    ...     ys = [y for y in ys if y is not None]
     ...
-    ...     y_scale = scale[1] or 1
-    ...     ys_scaled = [data[x] / y_scale for x in xs]
-    ...     xs_scaled = [x / scale[0] for x in xs]
-    ...     N = len(xs) - 2
-    ...     pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
-    ...     return sum(volume(pts[i:i+3]) for i in range(N)) / N
-
-    Or you may define a loss that favours the (local) minima of a function.
+    ...     if len(xs) == 2:  # we do not have enough points for a triangle
+    ...         return xs[1] - xs[0]
+    ...
+    ...     N = len(xs) - 2  # number of constructed triangles
+    ...     if isinstance(ys[0], Iterable):
+    ...         pts = [(x, *y) for x, y in zip(xs, ys)]
+    ...         vol = simplex_volume_in_embedding
+    ...     else:
+    ...         pts = [(x, y) for x, y in zip(xs, ys)]
+    ...         vol = volume
+    ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+    Or you may define a loss that favours the (local) minima of a function,
+    assuming that you know your function will have a single float as output.

     >>> @uses_nth_neighbors(1)
-    ... def local_minima_resolving_loss(interval, scale, data, neighbors):
-    ...     x_left, x_right = interval
-    ...     n_left = neighbors[x_left][0]
-    ...     n_right = neighbors[x_right][1]
-    ...     loss = (x_right - x_left) / scale[0]
+    ... def local_minima_resolving_loss(xs, ys):
+    ...     loss = xs[2] - xs[1]  # the width of the interval of interest
     ...
-    ...     if not ((n_left is not None and data[x_left] > data[n_left])
-    ...             or (n_right is not None and data[x_right] > data[n_right])):
+    ...     if not ((ys[0] is not None and ys[0] > ys[1])
+    ...             or (ys[3] is not None and ys[3] > ys[2])):
     ...         return loss * 100
     ...
     ...     return loss
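
For orientation, a loss written against the new signature plugs straight into `Learner1D` via ``loss_per_interval``. A minimal sketch (`edge_aware_loss` and the sampled function are invented for illustration):

    import adaptive
    from adaptive.learner.learner1D import uses_nth_neighbors

    @uses_nth_neighbors(1)
    def edge_aware_loss(xs, ys):  # hypothetical, not part of adaptive
        # xs == (x_left_neighbor, x_left, x_right, x_right_neighbor);
        # the neighbor entries are None at the edges of the domain
        dx = xs[2] - xs[1]  # width of the interval of interest
        # sample intervals that touch the domain boundary twice as eagerly
        return 2 * dx if (xs[0] is None or xs[3] is None) else dx

    learner = adaptive.Learner1D(lambda x: x**2, bounds=(-1, 1),
                                 loss_per_interval=edge_aware_loss)
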
@@ -68,9 +66,8 @@ def _wrapped(loss_per_interval):
         return loss_per_interval
     return _wrapped

-
 @uses_nth_neighbors(0)
-def uniform_loss(interval, scale, data, neighbors):
+def uniform_loss(xs, ys):
     """Loss function that samples the domain uniformly.

     Works with `~adaptive.Learner1D` only.
@@ -85,38 +82,36 @@ def uniform_loss(interval, scale, data, neighbors):
     ...                              loss_per_interval=uniform_loss)
     >>>
     """
-    x_left, x_right = interval
-    x_scale, _ = scale
-    dx = (x_right - x_left) / x_scale
+    dx = xs[1] - xs[0]
     return dx


 @uses_nth_neighbors(0)
-def default_loss(interval, scale, data, neighbors):
+def default_loss(xs, ys):
     """Calculate loss on a single interval.

     Currently returns the rescaled length of the interval. If one of the
     y-values is missing, returns 0 (so the intervals with missing data are
     never touched). This behavior should be improved later.
     """
-    x_left, x_right = interval
-    y_right, y_left = data[x_right], data[x_left]
-    x_scale, y_scale = scale
-    dx = (x_right - x_left) / x_scale
-    if y_scale == 0:
-        loss = dx
+    dx = xs[1] - xs[0]
+    if isinstance(ys[0], Iterable):
+        dy = [abs(a - b) for a, b in zip(*ys)]
+        return np.hypot(dx, dy).max()
     else:
-        dy = (y_right - y_left) / y_scale
-        try:
-            len(dy)
-            loss = np.hypot(dx, dy).max()
-        except TypeError:
-            loss = math.hypot(dx, dy)
-    return loss
+        dy = ys[1] - ys[0]
+        return np.hypot(dx, dy)


-def _loss_of_multi_interval(xs, ys):
-    N = len(xs) - 2
+@uses_nth_neighbors(1)
+def triangle_loss(xs, ys):
+    xs = [x for x in xs if x is not None]
+    ys = [y for y in ys if y is not None]
+
+    if len(xs) == 2:  # we do not have enough points for a triangle
+        return xs[1] - xs[0]
+
+    N = len(xs) - 2  # number of constructed triangles
     if isinstance(ys[0], Iterable):
         pts = [(x, *y) for x, y in zip(xs, ys)]
         vol = simplex_volume_in_embedding
@@ -126,27 +121,15 @@ def _loss_of_multi_interval(xs, ys):
     return sum(vol(pts[i:i+3]) for i in range(N)) / N


-@uses_nth_neighbors(1)
-def triangle_loss(interval, scale, data, neighbors):
-    x_left, x_right = interval
-    xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
-    xs = [x for x in xs if x is not None]
-
-    if len(xs) <= 2:
-        return (x_right - x_left) / scale[0]
-    else:
-        y_scale = scale[1] or 1
-        ys_scaled = [data[x] / y_scale for x in xs]
-        xs_scaled = [x / scale[0] for x in xs]
-        return _loss_of_multi_interval(xs_scaled, ys_scaled)
-
-
-def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
     @uses_nth_neighbors(1)
-    def curvature_loss(interval, scale, data, neighbors):
-        triangle_loss_ = triangle_loss(interval, scale, data, neighbors)
-        default_loss_ = default_loss(interval, scale, data, neighbors)
-        dx = (interval[1] - interval[0]) / scale[0]
+    def curvature_loss(xs, ys):
+        xs_middle = xs[1:3]
+        ys_middle = ys[1:3]
+
+        triangle_loss_ = triangle_loss(xs, ys)
+        default_loss_ = default_loss(xs_middle, ys_middle)
+        dx = xs_middle[1] - xs_middle[0]
         return (area_factor * (triangle_loss_**0.5)
                 + euclid_factor * default_loss_
                 + horizontal_factor * dx)
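
Because the losses now take plain tuples, they can also be exercised directly. A sketch with made-up, already-rescaled values (the learner normally rescales and passes these for you):

    from adaptive.learner.learner1D import default_loss, triangle_loss

    # nth_neighbors == 0: xs and ys hold only the interval's endpoints
    default_loss((0.0, 0.1), (0.0, 0.2))          # == np.hypot(0.1, 0.2)

    # nth_neighbors == 1: one extra point per side, None at a domain edge
    triangle_loss((None, 0.0, 0.1, 0.2), (None, 0.0, 0.2, 0.1))
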
@@ -209,29 +192,24 @@ class Learner1D(BaseLearner):

     Notes
     -----
-    `loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
-    ``data``, and ``neighbors``, and returns a scalar; the loss over
-    the interval.
-    interval : (float, float)
-        The bounds of the interval.
-    scale : (float, float)
-        The x and y scale over all the intervals, useful for rescaling the
-        interval loss.
-    data : dict(float → float)
-        A map containing evaluated function values. It is guaranteed
-        to have values for both of the points in 'interval'.
-    neighbors : dict(float → (float, float))
-        A map containing points as keys to its neighbors as a tuple.
-        At the left ``x_left`` and right ``x_left`` most boundary it has
-        ``x_left: (None, float)`` and ``x_right: (float, None)``.
-
-    The `loss_per_interval` function should also have
-    an attribute `nth_neighbors` that indicates how many of the neighboring
-    intervals to `interval` are used. If `loss_per_interval` doesn't
-    have such an attribute, it's assumed that is uses **no** neighboring
-    intervals. Also see the `uses_nth_neighbors` decorator.
-    **WARNING**: When modifying the `data` and `neighbors` datastructures
-    the learner will behave in an undefined way.
+    `loss_per_interval` takes 2 parameters: ``xs`` and ``ys``, and returns a
+    scalar; the loss over the interval.
+    xs : tuple of floats
+        The x-values of the interval. If `nth_neighbors` is greater than zero,
+        it also contains the x-values of the neighbors of the interval, in
+        ascending order. The interval we want to know the loss of is then the
+        middle interval. If no neighbor is available (at the edges of the
+        domain), `None` takes the place of the x-value of the neighbor.
+    ys : tuple of function values
+        The output values of the function when evaluated at the `xs`. This is
+        either a float or a tuple of floats in the case of vector output.
+
+    The `loss_per_interval` function may also have an attribute `nth_neighbors`
+    that indicates how many of the intervals neighboring the interval are used.
+    If `loss_per_interval` doesn't have such an attribute, it's assumed that it
+    uses **no** neighboring intervals. Also see the `uses_nth_neighbors`
+    decorator for more information.
     """

     def __init__(self, function, bounds, loss_per_interval=None):
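
To make the documented ``xs``/``ys`` layout concrete, here is what a loss with ``nth_neighbors == 1`` would receive in two situations (all numbers invented):

    # interval (-1.0, -0.9) at the left edge of the domain: no left neighbor
    xs = (None, -1.0, -0.9, -0.8)
    ys = (None, 0.50, 0.60, 0.40)

    # vector-valued function: every y is a tuple of floats
    xs = (0.0, 0.1, 0.2, 0.3)
    ys = ((0.5, 1.0), (0.6, 0.9), (0.4, 1.1), (0.3, 1.2))
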
@@ -300,16 +278,41 @@ def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
         return max(losses.values()) if len(losses) > 0 else float('inf')

+    def _scale_x(self, x):
+        if x is None:
+            return None
+        return x / self._scale[0]
+
+    def _scale_y(self, y):
+        if y is None:
+            return None
+        y_scale = self._scale[1] or 1
+        return y / y_scale
+
+    def _get_point_by_index(self, ind):
+        if ind < 0 or ind >= len(self.neighbors):
+            return None
+        return self.neighbors.keys()[ind]
+
     def _get_loss_in_interval(self, x_left, x_right):
         assert x_left is not None and x_right is not None

         if x_right - x_left < self._dx_eps:
             return 0

-        # we need to compute the loss for this interval
-        return self.loss_per_interval(
-            (x_left, x_right), self._scale, self.data, self.neighbors)
+        nn = self.nth_neighbors
+        i = self.neighbors.index(x_left)
+        start = i - nn
+        end = i + nn + 2
+
+        xs = [self._get_point_by_index(i) for i in range(start, end)]
+        ys = [self.data.get(x, None) for x in xs]

+        xs_scaled = tuple(self._scale_x(x) for x in xs)
+        ys_scaled = tuple(self._scale_y(y) for y in ys)
+
+        # we need to compute the loss for this interval
+        return self.loss_per_interval(xs_scaled, ys_scaled)

     def _update_interpolated_loss_in_interval(self, x_left, x_right):
         if x_left is None or x_right is None:
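
The index arithmetic above relies on ``self.neighbors`` being a sorted mapping (a `sortedcontainers.SortedDict` in adaptive), whose ``index`` method and indexable ``keys()`` view make the window lookup cheap. A standalone sketch of the same windowing with invented points:

    from sortedcontainers import SortedDict

    neighbors = SortedDict({x: None for x in (-1.0, -0.5, 0.0, 0.5)})
    data = {x: x**2 for x in neighbors}

    nn = 1                              # loss_per_interval.nth_neighbors
    i = neighbors.index(-0.5)           # position of x_left
    window = range(i - nn, i + nn + 2)  # interval plus nn points per side

    xs = [neighbors.keys()[j] if 0 <= j < len(neighbors) else None
          for j in window]
    ys = [data.get(x) for x in xs]      # None wherever a point is missing
    # xs == [-1.0, -0.5, 0.0, 0.5]
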
@@ -419,6 +422,9 @@ def tell(self, x, y):
         if x in self.data:
             # The point is already evaluated before
             return
+        if y is None:
+            raise TypeError("Y-value may not be None, use learner.tell_pending(x) "
+                            "to indicate that this value is currently being calculated")

         # either it is a float/int, if not, try casting to a np.array
         if not isinstance(y, (float, int)):
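
A short usage sketch of the contract this guard enforces (function and points invented):

    import adaptive

    learner = adaptive.Learner1D(lambda x: x**2, bounds=(-1, 1))
    learner.tell_pending(0.5)   # mark the point as being computed
    learner.tell(0.5, 0.25)     # report the result once it is known
    # learner.tell(0.7, None)   # raises TypeError after this change
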

adaptive/tests/test_learner1d.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import numpy as np

 from ..learner import Learner1D
-from ..learner.learner1D import get_curvature_loss
+from ..learner.learner1D import curvature_loss_function
 from ..runner import simple


@@ -347,7 +347,7 @@ def test_curvature_loss():
     def f(x):
         return np.tanh(20*x)

-    loss = get_curvature_loss()
+    loss = curvature_loss_function()
     assert loss.nth_neighbors == 1
     learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
     simple(learner, goal=lambda l: l.npoints > 100)
@@ -358,7 +358,7 @@ def test_curvature_loss_vectors():
     def f(x):
         return np.tanh(20*x), np.tanh(20*(x-0.4))

-    loss = get_curvature_loss()
+    loss = curvature_loss_function()
     assert loss.nth_neighbors == 1
     learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
     simple(learner, goal=lambda l: l.npoints > 100)

docs/source/reference/adaptive.learner.learner1D.rst

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ Custom loss functions

 .. autofunction:: adaptive.learner.learner1D.triangle_loss

-.. autofunction:: adaptive.learner.learner1D.get_curvature_loss
+.. autofunction:: adaptive.learner.learner1D.curvature_loss_function

docs/source/tutorial/tutorial.Learner1D.rst

Lines changed: 2 additions & 2 deletions
@@ -150,10 +150,10 @@ by specifying ``loss_per_interval``.

 .. jupyter-execute::

-    from adaptive.learner.learner1D import (get_curvature_loss,
+    from adaptive.learner.learner1D import (curvature_loss_function,
                                             uniform_loss,
                                             default_loss)
-    curvature_loss = get_curvature_loss()
+    curvature_loss = curvature_loss_function()
     learner = adaptive.Learner1D(f, bounds=(-1, 1), loss_per_interval=curvature_loss)
     runner = adaptive.Runner(learner, goal=lambda l: l.loss() < 0.01)

